{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"12_Embeddings","provenance":[{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/12_Embeddings.ipynb","timestamp":1608155949299},{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/12_Embeddings.ipynb","timestamp":1584545838024},{"file_id":"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/12_Embeddings.ipynb","timestamp":1583247939463}],"collapsed_sections":[],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"K-becVUr5_Q9"},"source":["<div align=\"center\">\n","<h1><img width=\"30\" src=\"../res/maiwei.png\">&nbsp;<a href=\"https://github.com/Charmve/computer-vision-in-action\">Computer Vision in Action</a></h1>\n","Applied ML · MLOps · Production\n","<br>\n","Join 20K+ developers in learning how to responsibly <a href=\"https://madewithml.com/about/\">deliver value</a> with ML.\n","</div>\n","\n","<br>\n","\n","<div align=\"center\">\n","    <a target=\"_blank\" href=\"https://newsletter.madewithml.com\"><img src=\"https://img.shields.io/badge/Subscribe-20K-brightgreen\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://github.com/GokuMohandas/MadeWithML\"><img src=\"https://img.shields.io/github/stars/GokuMohandas/MadeWithML.svg?style=social&label=Star\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://www.linkedin.com/in/goku\"><img src=\"https://img.shields.io/badge/style--5eba00.svg?label=LinkedIn&logo=linkedin&style=social\"></a>&nbsp;\n","    <a target=\"_blank\" href=\"https://twitter.com/GokuMohandas\"><img src=\"https://img.shields.io/twitter/follow/GokuMohandas.svg?label=Follow&style=social\"></a>\n","    <p>🔥&nbsp; Among the <a href=\"https://github.com/topics/deep-learning\" target=\"_blank\">top ML</a> repositories on 
GitHub</p>\n","</div>\n","\n","<br>\n","<hr>"]},{"cell_type":"markdown","metadata":{"id":"eTdCMVl9YAXw"},"source":["# Embeddings\n","\n","In this lesson, we will motivate the need for embeddings, which are capable of capturing the contextual, semantic and syntactic meaning in data."]},{"cell_type":"markdown","metadata":{"id":"xuabAj4PYj57"},"source":["<div align=\"left\">\n","<a target=\"_blank\" href=\"https://madewithml.com/courses/ml-foundations/embeddings/\"><img src=\"https://img.shields.io/badge/📖 Read-blog post-9cf\"></a>&nbsp;\n","<a href=\"https://github.com/GokuMohandas/MadeWithML/blob/main/notebooks/12_Embeddings.ipynb\" role=\"button\"><img src=\"https://img.shields.io/static/v1?label=&amp;message=View%20On%20GitHub&amp;color=586069&amp;logo=github&amp;labelColor=2f363d\"></a>&nbsp;\n","<a href=\"https://colab.research.google.com/github/GokuMohandas/MadeWithML/blob/main/notebooks/12_Embeddings.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n","</div>"]},{"cell_type":"markdown","metadata":{"id":"N9lh8_YvoR50"},"source":["# Overview\n"]},{"cell_type":"markdown","metadata":{"id":"Dtu7IU66obsh"},"source":["While one-hot encoding allows us to preserve the structural information, it poses two major disadvantages:\n","\n","- The size of each token's representation grows linearly with the number of unique tokens in our vocabulary, which is a problem if we're dealing with a large corpus.\n","- The representation for each token does not preserve any relationship with respect to other tokens.\n","\n","In this notebook, we're going to motivate the need for embeddings and show how they address these shortcomings of one-hot encoding. The main idea of embeddings is to have fixed-length representations for the tokens in a text regardless of the number of tokens in the vocabulary. With one-hot encoding, each token is represented by an array of size [1 X `vocab_size`], but with embeddings, each token now has the shape [1 X `embed_dim`]. 
The values in the representation are not fixed binary values but floating-point values that change during training, allowing for fine-grained learned representations.\n","\n","* **Objective:**  Represent tokens in text in a way that captures their intrinsic semantic relationships.\n","* **Advantages:** \n","    * Low-dimensionality while capturing relationships.\n","    * Interpretable token representations.\n","* **Disadvantages:** Can be computationally intensive to precompute.\n","* **Miscellaneous:** There are lots of pretrained embeddings to choose from, but you can also train your own from scratch.\n","\n","\n"]},{"cell_type":"markdown","metadata":{"id":"nH_O4MZ294jk"},"source":["# Learning Embeddings"]},{"cell_type":"markdown","metadata":{"id":"F47IiPgUupAk"},"source":["We can learn embeddings by creating our models in PyTorch, but first, we're going to use a library that specializes in embeddings and topic modeling called [Gensim](https://radimrehurek.com/gensim/). "]},{"cell_type":"code","metadata":{"id":"_pZljlaCgG6Y","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249517914,"user_tz":420,"elapsed":5623,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"0a90af35-f470-4d30-b9b4-de7fd847f632"},"source":["import nltk\n","nltk.download('punkt');\n","import numpy as np\n","import re\n","import urllib.request"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[nltk_data] Downloading package punkt to /root/nltk_data...\n","[nltk_data]   Unzipping tokenizers/punkt.zip.\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"oektJd55gG1p"},"source":["SEED = 1234"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tqbnugiD-SW0"},"source":["# Set seed for 
reproducibility\n","np.random.seed(SEED)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"vF5D_nNjlx2d","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249518769,"user_tz":420,"elapsed":6439,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"d72e77e6-486f-4177-a92c-3b1884fb0539"},"source":["# Split text into sentences\n","tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')\n","book = urllib.request.urlopen(url=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/datasets/harrypotter.txt\")\n","sentences = tokenizer.tokenize(str(book.read()))\n","print (f\"{len(sentences)} sentences\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["12443 sentences\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"xWyxaKfJOomF"},"source":["def preprocess(text):\n","    \"\"\"Conditional preprocessing on our text.\"\"\"\n","    # Lower\n","    text = text.lower()\n","\n","    # Spacing and filters\n","    text = re.sub(r\"([-;;.,!?<=>])\", r\" \\1 \", text)\n","    text = re.sub('[^A-Za-z0-9]+', ' ', text) # remove non alphanumeric chars\n","    text = re.sub(' +', ' ', text)  # remove multiple spaces\n","    text = text.strip()\n","\n","    # Separate into word tokens\n","    text = text.split(\" \")\n","\n","    return text"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NsZz5jfMlx0d","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249518770,"user_tz":420,"elapsed":6411,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"2005b37e-9d95-435c-f3cd-fdf1b451da66"},"source":["# Preprocess sentences\n","print 
(sentences[11])\n","sentences = [preprocess(sentence) for sentence in sentences]\n","print (sentences[11])"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Snape nodded, but did not elaborate.\n","['snape', 'nodded', 'but', 'did', 'not', 'elaborate']\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"rozFTf06ji1b"},"source":["But how do we learn the embeddings in the first place? The intuition behind embeddings is that the definition of a token depends not on the token itself but on its context. There are several different ways of doing this:\n","\n","1. Given the words in the context, predict the target word (CBOW - continuous bag of words).\n","2. Given the target word, predict the context words (skip-gram).\n","3. Given a sequence of words, predict the next word (LM - language modeling).\n","\n","All of these approaches involve creating data to train our model on. Every word in a sentence becomes the target word, and the context words are determined by a window. In the image below (skip-gram), the window size is 2 (2 words to the left and right of the target word). We repeat this for every sentence in our corpus, and this results in our training data for the unsupervised task. This is an unsupervised learning technique since we don't have explicit labels for contexts. The idea is that similar target words will appear with similar contexts, and we can learn this relationship by repeatedly training our model with (context, target) pairs.\n","\n","<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/python/skipgram.png\" width=\"600\">\n","</div>\n","\n","We can learn embeddings using any of the approaches above, and some work better than others. 
You can inspect the learned embeddings, but the best way to choose an approach is to empirically validate the performance on a supervised task."]},{"cell_type":"markdown","metadata":{"id":"No6c943C-P7o"},"source":["## Word2Vec"]},{"cell_type":"markdown","metadata":{"id":"VeszvcMOji4u"},"source":["When we have large vocabularies to learn embeddings for, things can get complex very quickly. Recall that backpropagation through a softmax over the entire vocabulary updates the weights for both the correct class and every incorrect class. This becomes a massive computation for every backward pass we do, so a workaround is to use [negative sampling](http://mccormickml.com/2017/01/11/word2vec-tutorial-part-2-negative-sampling/), which only updates the correct class and a few randomly sampled incorrect classes (`NEGATIVE_SAMPLING`=20). We're able to do this because of the large amount of training data, where we'll see the same word as the target class multiple times.\n","\n"]},{"cell_type":"code","metadata":{"id":"TqKCr--k-f9e"},"source":["import gensim\n","from gensim.models import KeyedVectors\n","from gensim.models import Word2Vec"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ufU-9l_W-QKj"},"source":["EMBEDDING_DIM = 100\n","WINDOW = 5\n","MIN_COUNT = 3 # Ignores all words with total frequency lower than this\n","SKIP_GRAM = 1 # 0 = CBOW\n","NEGATIVE_SAMPLING = 20"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Ha3I2oSsmhJa","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249530492,"user_tz":420,"elapsed":18082,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"9b3e1430-3430-438c-be18-6112815397aa"},"source":["# Super fast because of optimized C code under the hood\n","w2v = Word2Vec(\n","    sentences=sentences, size=EMBEDDING_DIM, \n","    window=WINDOW, min_count=MIN_COUNT, \n","    
sg=SKIP_GRAM, negative=NEGATIVE_SAMPLING)\n","print (w2v)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Word2Vec(vocab=4937, size=100, alpha=0.025)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"Cl6oJv8jmhHE","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249530493,"user_tz":420,"elapsed":18061,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a72c6cee-8dd7-42ee-f233-7441369105d7"},"source":["# Vector for each word\n","w2v.wv.get_vector(\"potter\")"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["array([-0.11787166, -0.2702948 ,  0.24332453,  0.07497228, -0.5299148 ,\n","        0.17751476, -0.30183575,  0.17060578, -0.0342238 , -0.331856  ,\n","       -0.06467848,  0.02454215,  0.4524056 , -0.18918884, -0.22446074,\n","        0.04246538,  0.5784022 ,  0.12316586,  0.03419832,  0.12895502,\n","       -0.36260423,  0.06671549, -0.28563526, -0.06784113, -0.0838319 ,\n","        0.16225453,  0.24313857,  0.04139925,  0.06982274,  0.59947336,\n","        0.14201492, -0.00841052, -0.14700615, -0.51149386, -0.20590985,\n","        0.00435914,  0.04931103,  0.3382509 , -0.06798466,  0.23954925,\n","       -0.07505646, -0.50945646, -0.44729665,  0.16253233,  0.11114362,\n","        0.05604156,  0.26727834,  0.43738437, -0.2606872 ,  0.16259147,\n","       -0.28841105, -0.02349186,  0.00743417,  0.08558545, -0.0844396 ,\n","       -0.44747537, -0.30635086, -0.04186366,  0.11142804,  0.03187608,\n","        0.38674814, -0.2663519 ,  0.35415238,  0.094676  , -0.13586426,\n","       -0.35296437, -0.31428036, -0.02917303,  0.02518964, -0.59744245,\n","       -0.11500382,  0.15761602,  0.30535367, -0.06207089,  0.21460988,\n","        0.17566076,  0.46426776,  0.15573359,  0.3675553 , -0.09043553,\n","        0.2774392 ,  
0.16967005,  0.32909656,  0.01422888,  0.4131812 ,\n","        0.20034142,  0.13722987,  0.10324971,  0.14308734,  0.23772323,\n","        0.2513108 ,  0.23396717, -0.10305202, -0.03343603,  0.14360961,\n","       -0.01891198,  0.11430877,  0.30017182, -0.09570111, -0.10692801],\n","      dtype=float32)"]},"metadata":{"tags":[]},"execution_count":15}]},{"cell_type":"code","metadata":{"id":"DyuLX9DTnLvM","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249530494,"user_tz":420,"elapsed":18041,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"2e8d31cf-d7cb-4d63-957b-cbfaf6341fa4"},"source":["# Get nearest neighbors (excluding itself)\n","w2v.wv.most_similar(positive=\"scar\", topn=5)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('pain', 0.9274871349334717),\n"," ('forehead', 0.9020695686340332),\n"," ('heart', 0.8953317999839783),\n"," ('mouth', 0.8939940929412842),\n"," ('throat', 0.8922691345214844)]"]},"metadata":{"tags":[]},"execution_count":16}]},{"cell_type":"code","metadata":{"id":"YT7B0KRVTFew"},"source":["# Saving and loading\n","w2v.wv.save_word2vec_format('model.bin', binary=True)\n","w2v = KeyedVectors.load_word2vec_format('model.bin', binary=True)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"JZXVP5vfuiD5"},"source":["## FastText"]},{"cell_type":"markdown","metadata":{"id":"uvuoeWYMuqsa"},"source":["What happens when a word doesn't exist in our vocabulary? We could assign an UNK token that is shared by all OOV (out-of-vocabulary) words, or we could use [FastText](https://radimrehurek.com/gensim/models/fasttext.html), which uses character-level n-grams to embed a word. 
This helps embed rare words, misspelled words, and also words that don't exist in our corpus but are similar to words in our corpus."]},{"cell_type":"code","metadata":{"id":"fVg3PBeD-kAa"},"source":["from gensim.models import FastText"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"eTNW4Mfgrpo0","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249545432,"user_tz":420,"elapsed":32935,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"d7aa13b8-5beb-48dc-f661-d841edde385a"},"source":["# Super fast because of optimized C code under the hood\n","ft = FastText(sentences=sentences, size=EMBEDDING_DIM, \n","              window=WINDOW, min_count=MIN_COUNT, \n","              sg=SKIP_GRAM, negative=NEGATIVE_SAMPLING)\n","print (ft)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["FastText(vocab=4937, size=100, alpha=0.025)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"LbA4vU5uxiw3"},"source":["# This word doesn't exist so the word2vec model will error out\n","# w2v.wv.most_similar(positive=\"scarring\", topn=5)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"eRG30aE4sMjt","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249545434,"user_tz":420,"elapsed":32903,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"65bdc527-c62e-4da7-c3c1-5136695855ec"},"source":["# FastText will use n-grams to embed an OOV word\n","ft.wv.most_similar(positive=\"scarring\", topn=5)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('sparkling', 0.9785991907119751),\n"," ('coiling', 0.9770463705062866),\n"," 
('watering', 0.9759057760238647),\n"," ('glittering', 0.9756022095680237),\n"," ('dazzling', 0.9755154848098755)]"]},"metadata":{"tags":[]},"execution_count":21}]},{"cell_type":"code","metadata":{"id":"7SE5fPMUnLyP"},"source":["# Saving and loading\n","ft.wv.save('model.bin')\n","ft = KeyedVectors.load('model.bin')"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"67UmjtK0pF9X"},"source":["# Pretrained embeddings"]},{"cell_type":"markdown","metadata":{"id":"Xm1GPn4spF6x"},"source":["We can learn embeddings from scratch using one of the approaches above, but we can also leverage pretrained embeddings that have been trained on millions of documents. Popular ones include [Word2Vec](https://www.tensorflow.org/tutorials/text/word2vec) (skip-gram) and [GloVe](https://nlp.stanford.edu/projects/glove/) (global word-word co-occurrence). We can check that these embeddings capture meaningful semantic relationships by probing them with word analogies and nearest-neighbor queries."]},{"cell_type":"code","metadata":{"id":"Hh42Mb4lLbuB"},"source":["from gensim.scripts.glove2word2vec import glove2word2vec\n","from io import BytesIO\n","import matplotlib.pyplot as plt\n","from sklearn.decomposition import PCA\n","from urllib.request import urlopen\n","from zipfile import ZipFile"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"m9gxHJA9M8hK"},"source":["# Arguments\n","EMBEDDING_DIM = 100"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ANfQHxGrMKTe"},"source":["def plot_embeddings(words, embeddings, pca_results):\n","    \"\"\"Scatter-plot the 2D PCA coordinates of the given words.\"\"\"\n","    for word in words:\n","        # Rows of pca_results follow the vocabulary order of `embeddings`\n","        index = embeddings.index2word.index(word)\n","        plt.scatter(pca_results[index, 0], pca_results[index, 1])\n","        plt.annotate(word, xy=(pca_results[index, 0], pca_results[index, 1]))\n","    
plt.show()"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"ZW9Qtkz3LfdY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250237848,"user_tz":420,"elapsed":427321,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"d6953cc6-b41a-453f-c632-a234776e075b"},"source":["# Download the zip archive into memory (may take ~3-5 minutes)\n","resp = urlopen('http://nlp.stanford.edu/data/glove.6B.zip')\n","zipfile = ZipFile(BytesIO(resp.read()))\n","zipfile.namelist()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['glove.6B.50d.txt',\n"," 'glove.6B.100d.txt',\n"," 'glove.6B.200d.txt',\n"," 'glove.6B.300d.txt']"]},"metadata":{"tags":[]},"execution_count":69}]},{"cell_type":"code","metadata":{"id":"bWnVBrOaLjIC","colab":{"base_uri":"https://localhost:8080/","height":35},"executionInfo":{"status":"ok","timestamp":1608250240804,"user_tz":420,"elapsed":430134,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"093636f1-3952-4328-e7eb-1b308b62cc9c"},"source":["# Extract the embeddings file for the chosen dimension to disk\n","embeddings_file = 'glove.6B.{0}d.txt'.format(EMBEDDING_DIM)\n","zipfile.extract(embeddings_file)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"},"text/plain":["'/content/glove.6B.100d.txt'"]},"metadata":{"tags":[]},"execution_count":70}]},{"cell_type":"code","metadata":{"id":"qFLyIqIxrUIs","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250240806,"user_tz":420,"elapsed":429968,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"bc8a3ead-34e6-46c0-b5f7-6471cf523acd"},"source":["# Preview of the GloVe embeddings file\n","with open(embeddings_file, 'r') as fp:\n","    line = next(fp)\n","    values = line.split()\n","    word = values[0]\n","    embedding = np.asarray(values[1:], dtype='float32')\n","    print (f\"word: {word}\")\n","    print (f\"embedding:\\n{embedding}\")\n","    print (f\"embedding dim: {len(embedding)}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["word: the\n","embedding:\n","[-0.038194 -0.24487   0.72812  -0.39961   0.083172  0.043953 -0.39141\n","  0.3344   -0.57545   0.087459  0.28787  -0.06731   0.30906  -0.26384\n"," -0.13231  -0.20757   0.33395  -0.33848  -0.31743  -0.48336   0.1464\n"," -0.37304   0.34577   0.052041  0.44946  -0.46971   0.02628  -0.54155\n"," -0.15518  -0.14107  -0.039722  0.28277   0.14393   0.23464  -0.31021\n","  0.086173  0.20397   0.52624   0.17164  -0.082378 -0.71787  -0.41531\n","  0.20335  -0.12763   0.41367   0.55187   0.57908  -0.33477  -0.36559\n"," -0.54857  -0.062892  0.26584   0.30205   0.99775  -0.80481  -3.0243\n","  0.01254  -0.36942   2.2167    0.72201  -0.24978   0.92136   0.034514\n","  0.46745   1.1079   -0.19358  -0.074575  0.23353  -0.052062 -0.22044\n","  0.057162 -0.15806  -0.30798  -0.41625   0.37972   0.15006  -0.53212\n"," -0.2055   -1.2526    0.071624  0.70565   0.49744  -0.42063   0.26148\n"," -1.538    -0.30223  -0.073438 -0.28312   0.37104  -0.25217   0.016215\n"," -0.017099 -0.38984   0.87424  -0.72569  -0.51058  -0.52028  -0.1459\n","  0.8278    0.27062 ]\n","embedding dim: 100\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"9eD5doqFLjFY","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250242031,"user_tz":420,"elapsed":430805,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"59310cfa-e600-4ce5-ddca-28b308591241"},"source":["# Save GloVe embeddings to local directory in word2vec format\n","word2vec_output_file = '{0}.word2vec'.format(embeddings_file)\n","glove2word2vec(embeddings_file, word2vec_output_file)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["(400000, 100)"]},"metadata":{"tags":[]},"execution_count":72}]},{"cell_type":"code","metadata":{"id":"To4sx_1iMCX0"},"source":["# Load embeddings (may take a minute)\n","glove = KeyedVectors.load_word2vec_format(word2vec_output_file, binary=False)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"UEhBhvgHMEH9","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250277481,"user_tz":420,"elapsed":465141,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"85582e56-ccd1-40eb-e7d4-7d376adf17ac"},"source":["# (king - man) + woman = ?\n","glove.most_similar(positive=['woman', 'king'], negative=['man'], topn=5)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('queen', 0.7698541283607483),\n"," ('monarch', 0.6843380928039551),\n"," ('throne', 0.6755735874176025),\n"," ('daughter', 0.6594556570053101),\n"," ('princess', 0.6520534753799438)]"]},"metadata":{"tags":[]},"execution_count":74}]},{"cell_type":"code","metadata":{"id":"xR94AICkMEFV","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250277482,"user_tz":420,"elapsed":464638,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"209fcf3f-6bb8-4593-872d-aac59d719431"},"source":["# Get nearest neighbors (excluding itself)\n","glove.most_similar(positive=\"goku\", topn=5)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('gohan', 0.7246542572975159),\n"," ('bulma', 0.6497020125389099),\n"," ('raistlin', 0.6443604230880737),\n"," ('skaar', 0.6316742897033691),\n"," ('guybrush', 0.6231324672698975)]"]},"metadata":{"tags":[]},"execution_count":75}]},{"cell_type":"code","metadata":{"id":"gseqjBmzMECq","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250280774,"user_tz":420,"elapsed":467358,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"491e696c-7bb1-4eb0-cbc4-aaaa7f134afa"},"source":["# Reduce dimensionality for plotting\n","X = glove[glove.vocab]\n","pca = PCA(n_components=2)\n","pca_results = pca.fit_transform(X)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"LFQWGyncMHgK","colab":{"base_uri":"https://localhost:8080/","height":265},"executionInfo":{"status":"ok","timestamp":1608250280988,"user_tz":420,"elapsed":467060,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"33817b31-e29e-460f-cfad-37aabb53c884"},"source":["# Visualize\n","plot_embeddings(\n","    words=[\"king\", \"queen\", \"man\", \"woman\"], embeddings=glove,  \n","    pca_results=pca_results)"],"execution_count":null,"outputs":[{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAXsAAAD4CAYAAAANbUbJAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAV1ElEQVR4nO3df5BV9Znn8fdDg90oDuyGxhDQgLXgD2ykm8Yag8QWs4KiMFNJHCmyGyeKqcSJxkQ0Zo2yWkllgrX+yGY0OKFQU6ImKgWoI/5ARR0HGkUmoAiLPSvoCLqmIwKRxu/+0W1Pg0Df7r59L7fP+1XVVfd8z/ec8zx1qY/Hc869N1JKSJJ6tl7FLkCS1P0Me0nKAMNekjLAsJekDDDsJSkDehfrwAMHDkzDhg0r1uElqSStWrXqvZRSZUe3K1rYDxs2jPr6+mIdXpJKUkT8W2e28zKOJGWAYV/iGhoaOOmkk/Yaq6+v57LLLitSRZIORUW7jKPuU1tbS21tbbHLkHQI8cy+B9m0aRPV1dXMmTOHc889F4DZs2fzrW99i7q6Oo499lhuu+221vk33ngjxx13HKeddhrTp0/npptuKlbpkrqZZ/Y9xPr167nggguYP38+H3zwAc8++2zrutdff51ly5bx4Ycfctxxx/Gd73yH1atX8+CDD/Lqq6+ye/duampqGDt2bBE7kNSdDPsStPCVLcx5fD1v/3En/zk1svmdd5k2bRoPPfQQJ554Is8888xe86dMmUJ5eTnl5eUMGjSId999lxdeeIFp06ZRUVFBRUUF5513XnGakVQQXsYpMQtf2cI1D/0rW/64kwS8+6dd7KCciv90FM8///x+tykvL299XVZWRlNTU4GqlXSoMOxLzJzH17Nz9569B3uVUXH2Vdx9993ce++9Oe1n/PjxLF68mF27drF9+3aWLFnSDdVKOlQY9iXm7T/u3O/4uztgyZIl3HzzzfzpT39qdz/jxo1j6tSpjB49mrPPPpuqqir69++f73IlHSKiWD9eUltbm/wEbceN//nTbNlP4A8Z0JcXfjSxQ/vavn07/fr1Y8eOHXz5y19m7ty51NTU5KtUSd0gIlallDr8bHW7Z/YRMS8itkbEH9qZNy4imiLiax0tQrmbNek4+vYp22usb58yZk06rsP7uuSSSxgzZgw1NTV89atfNeilHiyXp3HmA/8buPtAEyKiDPh7YGl+ytKB/FX1EIDWp3G+MKAvsyYd1zreEble35dU+toN+5TScxExrJ1p3wMeBMbloSa146+qh3Qq3CVlV5dv0EbEEOCvgdtzmHtJRNRHRP22bdu6emhJUo7y8TTOLcDVKaVP2puYUpqbUqpNKdVWVnb465glSZ2Uj0/Q1gL3RQTAQOCciGhKKS3Mw74lSXnQ5bBPKQ3/9HVEzAeWGPSSdGhpN+wjYgFQBwyMiM3A9UAfgJTSHd1anSQpL3J5Gmd6rjtLKV3YpWokSd3Cr0uQpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQMMe0nKAMNekjLAsJekDDD
sJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQMMe0nKAMNekjLAsJekDDDsJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQMMe0nKAMNekjLAsJekDDDsJSkDDHtJyoB2wz4i5kXE1oj4wwHWz4iINRHxrxHxYkScnP8yJUldkcuZ/Xxg8kHWvwmcnlKqAm4E5uahLklSHvVub0JK6bmIGHaQ9S+2WXwJGNr1siRJ+ZTva/YXAY8daGVEXBIR9RFRv23btjwfWpJ0IHkL+4g4g+awv/pAc1JKc1NKtSml2srKynwdWpLUjnYv4+QiIkYD/wicnVJ6Px/7lCTlT5fP7CPiGOAh4L+llN7oekmSpHxr98w+IhYAdcDAiNgMXA/0AUgp3QFcB3wO+IeIAGhKKdV2V8GSpI7L5Wmc6e2svxi4OG8VSZLyzk/QSlIGGPaSlAGGvSRlgGEvSRlg2EtSBhj2kpQBhr0kZYBhL0kZYNhLUgYY9pKUAYa9JGWAYS9JGWDYS1IGGPaSlAGGvSRlgGEvSRlg2EtSBhj2kpQBhr0kFUFDQwPHH388F154ISNHjmTGjBk8+eSTjB8/nhEjRrBixQpWrFjBqaeeSnV1NV/60pdYv349ABFxYUQ8FBH/FBEbIuIX7R2v3d+glSR1j40bN/K73/2OefPmMW7cOO69916ef/55Fi1axM9+9jPuvvtuli9fTu/evXnyySf58Y9/3HbzMUA18GdgfUT8MqX01oGOZdhLUpEMHz6cqqoqAEaNGsWZZ55JRFBVVUVDQwONjY1885vfZMOGDUQEu3fvbrv5UymlRoCIWAd8ETDsJanYHtn0CLe+fCv//tG/0/+j/nwcH7eu69WrF+Xl5a2vm5qa+MlPfsIZZ5zBww8/TENDA3V1dW139+c2r/fQTp57zV6SCuCRTY8w+8XZvPPROyQSW3dsZeuOrTyy6ZEDbtPY2MiQIUMAmD9/fpeOb9hLUgHc+vKt7Nqza6+xROLWl2894DZXXXUV11xzDdXV1TQ1NXXp+JFS6tIOOqu2tjbV19cX5diSVGij7xpN4rN5GwRrvrkm5/1ExKqUUm1Hj++ZvSQVwOeP+HyHxvPNsJekAri85nIqyir2Gqsoq+DymssLcnyfxpGkAphy7BSA1qdxPn/E57m85vLW8e5m2EtSgUw5dkrBwn1fXsaRpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQPaDfuImBcRWyPiDwdYHxFxW0RsjIg1EVGT/zIlSV2Ry5n9fGDyQdafDYxo+bsEuL3rZUmS8qndsE8pPQf8v4NMmQbcnZq9BAyIiMH5KlCS1HX5uGY/hL1/HWVzy9hnRMQlEVEfEfXbtm3Lw6ElSbko6A3alNLclFJtSqm2srKykIeWpEzLR9hvAY5uszy0ZUySdIjIR9gvAv57y1M5fwk0ppTeycN+JUl50u63XkbEAqAOGBgRm4HrgT4AKaU7gEeBc4CNwA7gb7urWElS57Qb9iml6e2sT8CleatIkpR3foJWkjLAsJekDDDsJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQMMe0nKAMNekjLAsJekDDDsJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQMMe0nKAMNekjLAsJekDDDsJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScoAw16SMsCwl6QMMOwlKQNyCvuImBwR6yNiY0T8aD/rj4mIZRHxSkSsiYhz8l+qJKmz2g37iCgDfgWcDZwITI+IE/eZdi3wQEqpGrgA+Id8FypJ6rxczuxPATamlDallD4G7gOm7TMnAX/R8ro/8Hb+SpQkdVUuYT8EeKvN8uaWsbZmA9+IiM3Ao8D39rejiLgkIuojon7btm2dKFeS1Bn5ukE7HZifUhoKnAPcExGf2Xd
KaW5KqTalVFtZWZmnQ0uS2pNL2G8Bjm6zPLRlrK2LgAcAUkr/DFQAA/NRoCSp63IJ+5XAiIgYHhGH0XwDdtE+c/4vcCZARJxAc9h7nUaSDhHthn1KqQn4O+Bx4DWan7pZGxE3RMTUlmk/BGZGxKvAAuDClFLqrqIlSR3TO5dJKaVHab7x2nbsujav1wHj81uaJClf/AStJGWAYS9JGWDYS1IGGPaSlAGGvSRlgGEvSRlg2EtSBhj2kpQBhr0kZYBhL0kZYNhLUgYY9pKUAYa9JGWAYS9JGWDYSwU0Z84cbrvtNgCuuOIKJk6cCMDTTz/NjBkzWLBgAVVVVZx00klcffXVrdv169ePWbNmMWrUKL7yla+wYsUK6urqOPbYY1m0qPm3hBoaGpgwYQI1NTXU1NTw4osvAvDMM89QV1fH1772NY4//nhmzJiBPzeRPYa9VEATJkxg+fLlANTX17N9+3Z2797N8uXLGTlyJFdffTVPP/00q1evZuXKlSxcuBCAjz76iIkTJ7J27VqOPPJIrr32Wp544gkefvhhrruu+aclBg0axBNPPMHLL7/M/fffz2WXXdZ63FdeeYVbbrmFdevWsWnTJl544YXCN6+iMuylAmhcvJgNE8/k8Av/lpeWLOGt+++nvLycU089lfr6epYvX86AAQOoq6ujsrKS3r17M2PGDJ577jkADjvsMCZPngxAVVUVp59+On369KGqqoqGhgYAdu/ezcyZM6mqquLrX/8669ataz3+KaecwtChQ+nVqxdjxoxp3UbZYdhL3axx8WLe+cl1NL39Nn2AIb16cccPfkjNwIFMmDCBZcuWsXHjRoYNG3bAffTp04eIAKBXr16Ul5e3vm5qagLg5ptv5qijjuLVV1+lvr6ejz/+uHX7T+cDlJWVtW6j7DDspW629eZbSLt2tS6P7duXeVvf5cTX1zNhwgTuuOMOqqurOeWUU3j22Wd577332LNnDwsWLOD000/P+TiNjY0MHjyYXr16cc8997Bnz57uaEclyrCXulnTO+/stTy27+G819RE1c6dHHXUUVRUVDBhwgQGDx7Mz3/+c8444wxOPvlkxo4dy7Rp03I+zne/+13uuusuTj75ZF5//XWOOOKIfLeiEhbFuitfW1ub6uvri3JsqZA2TDyTprff/sx47y98gRFPP1WEilTKImJVSqm2o9t5Zi91s0FXfJ+oqNhrLCoqGHTF94tUkbKod7ELkHq6/uedBzRfu2965x16Dx7MoCu+3zouFYJhLxVA//POM9xVVF7GkaQMMOwlKQMMe0nKAMNekjLAsJekDDDsJSkDDHtJygDDXpIywLCXpAww7CUpAwx7ScqAnMI+IiZHxPqI2BgRPzrAnPMjYl1ErI2Ie/NbpiSpK9r9IrSIKAN+BfxXYDOwMiIWpZTWtZkzArgGGJ9S+iAiBnVXwZKkjsvlzP4UYGNKaVNK6WPgPmDfn8+ZCfwqpfQBQEppa37LlCR1RS5hPwR4q83y5paxtkYCIyPihYh4KSIm729HEXFJRNRHRP22bds6V7EkqcPydYO2NzACqAOmA3dGxIB9J6WU5qaUalNKtZWVlXk6tCSpPbmE/Rbg6DbLQ1vG2toMLEop7U4pvQm8QXP4S5IOAbmE/UpgREQMj4jDgAuARfvMWUjzWT0RMZDmyzqb8linJKkL2g37lFIT8HfA48BrwAMppbURcUNETG2Z9jjwfkSsA5YBs1JK73dX0ZKkjomUUlEOXFtbm+rr64tybEkqVRGxKqVU29Ht/AStJGWAYS9JGWDYS1IGGPaSlAGGvSRlgGEvSRlQ0mH/05/+lJEjR3Laaacxffp0brrpJurq6vj0kc733nuPYcOGAbBnzx5mzZrFuHHjGD16NL/+9a9b9zNnzpzW8euvvx6AhoYGTjjhBGbOnMmoUaM466yz2LlzZ8F7lKR8KNmwX7VqFffddx+rV6/m0UcfZeXKlQed/5vf/Ib+/fuzcuVKVq5cyZ1
33smbb77J0qVL2bBhAytWrGD16tWsWrWK5557DoANGzZw6aWXsnbtWgYMGMCDDz5YiNYkKe/a/T77Q8qaB+CpG6BxM8tX9+WvvzSeww8/HICpU6cedNOlS5eyZs0afv/73wPQ2NjIhg0bWLp0KUuXLqW6uhqA7du3s2HDBo455hiGDx/OmDFjABg7diwNDQ3d15skdaPSCfs1D8Diy2B3y6WUXR/AG//UPD76/NZpvXv35pNPPmmesmtX63hKiV/+8pdMmjRpr90+/vjjXHPNNXz729/ea7yhoYHy8vLW5bKyMi/jSCpZpXMZ56kb/iPogS9/sTcL1+1k52Oz+fDDD1m8eDEAw4YNY9WqVQCtZ/EAkyZN4vbbb2f37t0AvPHGG3z00UdMmjSJefPmsX37dgC2bNnC1q3+9oqknqV0zuwbN++1WDO4jL8Z1YeTf7GeQYvPZty4cQBceeWVnH/++cydO5cpU6a0zr/44otpaGigpqaGlBKVlZUsXLiQs846i9dee41TTz0VgH79+vHb3/6WsrKywvUmSd2sdL4I7eaToPGtz473Pxqu+AOzZ8+mX79+XHnllfkrUpIOMT3/i9DOvA769N17rE/f5nFJ0kGVzmWcT2/CtjyNQ/+hzUHfMj579uzi1SZJh7jSCXtoDvY2T95IknJTOpdxJEmdZthLUgYY9pKUAYa9JGWAYS9JGVC0D1VFxDbg37pp9wOB97pp38XWU3vrqX1Bz+2tp/YFh3ZvX0wpVXZ0o6KFfXeKiPrOfMKsFPTU3npqX9Bze+upfUHP7M3LOJKUAYa9JGVATw37ucUuoBv11N56al/Qc3vrqX1BD+ytR16zlyTtraee2UuS2jDsJSkDSjbsI6IiIlZExKsRsTYi/ucB5p0fEeta5txb6Do7I5feIuKYiFgWEa9ExJqIOKcYtXZGRJS11L1kP+vKI+L+iNgYEf8SEcMKX2HntNPXD1r+Ha6JiKci4ovFqLGzDtZbmzlfjYgUESXzyGJ7fZVifhxIaX3F8d7+DExMKW2PiD7A8xHxWErppU8nRMQI4BpgfErpg4gYVKxiO6jd3oBrgQdSSrdHxInAo8CwItTaGZcDrwF/sZ91FwEfpJT+S0RcAPw98DeFLK4LDtbXK0BtSmlHRHwH+AWl0xccvDci4siWOf9SyKLy4IB9lXB+7FfJntmnZttbFvu0/O17t3km8KuU0gct25TEL4nn2FviP/6B9gfeLlB5XRIRQ4EpwD8eYMo04K6W178HzoyIKERtXdFeXymlZSmlHS2LLwFDC1VbV+XwngHcSPN/mHcVpKg8yKGvksyPAynZsIfW/wVbDWwFnkgp7XtWMRIYGREvRMRLETG58FV2Tg69zQa+ERGbaT6r/16BS+ysW4CrgE8OsH4I8BZASqkJaAQ+V5jSuqS9vtq6CHise8vJq4P2FhE1wNEppUcKWlXXtfeelWx+7E9Jh31KaU9KaQzNZ0mnRMRJ+0zpDYwA6oDpwJ0RMaCwVXZODr1NB+anlIYC5wD3RMQh/X5GxLnA1pTSqmLXkk8d6SsivgHUAnO6vbA8aK+3ln9z/wv4YUEL66Ic37OSzY/9OaTDIVcppT8Cy4B9/8u7GViUUtqdUnoTeIPmN69kHKS3i4AHWub8M1BB85c3HcrGA1MjogG4D5gYEb/dZ84W4GiAiOhN8yWq9wtZZCfk0hcR8RXgfwBTU0p/LmyJndZeb0cCJwHPtMz5S2BRCdykzeU9K/n82EtKqST/gEpgQMvrvsBy4Nx95kwG7mp5PZDmywOfK3bteertMeDCltcn0HzNPopdewd6rAOW7Gf8UuCOltcX0HwTuuj15qGvauD/ACOKXWO+e9tnzjM034guer15eM9KMj8O9FfKZ/aDgWURsQZYSfN17SURcUNETG2Z8zjwfkSso/nseFZK6VA/S4TcevshMDMiXgUW0Bz8Jflx6H36+g3wuYjYCPwA+FHxKuuaffqaA/QDfhcRqyN
iURFL67J9eusxekh+7JdflyBJGVDKZ/aSpBwZ9pKUAYa9JGWAYS9JGWDYS1IGGPaSlAGGvSRlwP8HYu41jSQJwzYAAAAASUVORK5CYII=\n","text/plain":["<Figure size 432x288 with 1 Axes>"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"code","metadata":{"id":"MzrZ2_RBMHdn","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250280990,"user_tz":420,"elapsed":466243,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"1cabc6da-79d8-485b-d6d6-c85994e12247"},"source":["# Bias in embeddings\n","glove.most_similar(positive=['woman', 'doctor'], negative=['man'], topn=5)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["[('nurse', 0.7735227346420288),\n"," ('physician', 0.7189429998397827),\n"," ('doctors', 0.6824328303337097),\n"," ('patient', 0.6750682592391968),\n"," ('dentist', 0.6726033687591553)]"]},"metadata":{"tags":[]},"execution_count":78}]},{"cell_type":"markdown","metadata":{"id":"EE1kCvwnkBPc"},"source":["# Set up"]},{"cell_type":"code","metadata":{"id":"m_DIRj8G5uOC"},"source":["import numpy as np\n","import pandas as pd\n","import random\n","import torch\n","import torch.nn as nn"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"5tdPACZf5uTo"},"source":["SEED = 1234"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"NMt7hJuB5uXN"},"source":["def set_seeds(seed=1234):\n","    \"\"\"Set seeds for reproducibility.\"\"\"\n","    np.random.seed(seed)\n","    random.seed(seed)\n","    torch.manual_seed(seed)\n","    torch.cuda.manual_seed(seed)\n","    torch.cuda.manual_seed_all(seed) # multi-GPU# Set seeds for reproducibility\n","set_seeds(seed=SEED)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"mQzlWknv5ua6"},"source":["# Set seeds for 
reproducibility\n","set_seeds(seed=SEED)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"M1-vnM-P8i_P"},"source":["# Set device\n","cuda = True\n","device = torch.device('cuda' if (\n","    torch.cuda.is_available() and cuda) else 'cpu')\n","torch.set_default_tensor_type('torch.FloatTensor')\n","if device.type == 'cuda':\n","    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n","print (device)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"NfNfv1NTkKXa"},"source":["## Load data"]},{"cell_type":"markdown","metadata":{"id":"YiQk-yClkL3s"},"source":["We will download the [AG News dataset](http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html), which consists of 120K text samples from 4 unique classes (`Business`, `Sci/Tech`, `Sports`, `World`)"]},{"cell_type":"code","metadata":{"id":"5HLyMQBDj__P"},"source":["import numpy as np\n","import pandas as pd"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SNfmmNBokAHI","colab":{"base_uri":"https://localhost:8080/","height":204},"executionInfo":{"status":"ok","timestamp":1608249763205,"user_tz":420,"elapsed":1037,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"35775f46-c599-4ed8-d3e7-5ee126259e90"},"source":["# Load data\n","url = \"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/datasets/news.csv\"\n","df = pd.read_csv(url, header=0) # load\n","df = df.sample(frac=1).reset_index(drop=True) # shuffle\n","df.head()"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/html":["<div>\n","<style scoped>\n","    .dataframe tbody tr th:only-of-type {\n","        vertical-align: middle;\n","    }\n","\n","    .dataframe tbody tr th {\n","        vertical-align: top;\n","    }\n","\n","    .dataframe thead th {\n","        text-align: right;\n","    
}\n","</style>\n","<table border=\"1\" class=\"dataframe\">\n","  <thead>\n","    <tr style=\"text-align: right;\">\n","      <th></th>\n","      <th>title</th>\n","      <th>category</th>\n","    </tr>\n","  </thead>\n","  <tbody>\n","    <tr>\n","      <th>0</th>\n","      <td>The Spirit of St. Louis in outer space</td>\n","      <td>Sci/Tech</td>\n","    </tr>\n","    <tr>\n","      <th>1</th>\n","      <td>Oil Rebounds \\$1 as Heating Supplies Fall</td>\n","      <td>Business</td>\n","    </tr>\n","    <tr>\n","      <th>2</th>\n","      <td>Orlovsky not bitter at BC</td>\n","      <td>Sports</td>\n","    </tr>\n","    <tr>\n","      <th>3</th>\n","      <td>Report Says Air Force's Space Programs Improved</td>\n","      <td>Sci/Tech</td>\n","    </tr>\n","    <tr>\n","      <th>4</th>\n","      <td>From the Grass Roots of Iowa Comes the Thinkin...</td>\n","      <td>Sports</td>\n","    </tr>\n","  </tbody>\n","</table>\n","</div>"],"text/plain":["                                               title  category\n","0             The Spirit of St. Louis in outer space  Sci/Tech\n","1          Oil Rebounds \\$1 as Heating Supplies Fall  Business\n","2                          Orlovsky not bitter at BC    Sports\n","3    Report Says Air Force's Space Programs Improved  Sci/Tech\n","4  From the Grass Roots of Iowa Comes the Thinkin...    
Sports"]},"metadata":{"tags":[]},"execution_count":24}]},{"cell_type":"markdown","metadata":{"id":"Bk0a7TE2kTq4"},"source":["## Preprocessing"]},{"cell_type":"markdown","metadata":{"id":"yeTyLL8-kU9F"},"source":["We're going to clean up our input data first with operations such as lowercasing the text, removing stop (filler) words and applying regular-expression filters."]},{"cell_type":"code","metadata":{"id":"ZIrwF49UkAJ9"},"source":["import nltk\n","from nltk.corpus import stopwords\n","from nltk.stem import PorterStemmer\n","import re"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"TQR8I3HxkAMS","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249764788,"user_tz":420,"elapsed":1910,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"4a2f2da6-af99-4b87-af7e-0a64779cffe6"},"source":["nltk.download('stopwords')\n","STOPWORDS = stopwords.words('english')\n","print (STOPWORDS[:5])\n","porter = PorterStemmer()"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[nltk_data] Downloading package stopwords to /root/nltk_data...\n","[nltk_data]   Unzipping corpora/stopwords.zip.\n","['i', 'me', 'my', 'myself', 'we']\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"g43E1Oa1kAO4"},"source":["def preprocess(text, stopwords=STOPWORDS):\n","    \"\"\"Conditional preprocessing on our text unique to our task.\"\"\"\n","    # Lower\n","    text = text.lower()\n","\n","    # Remove stopwords\n","    pattern = re.compile(r'\\b(' + r'|'.join(stopwords) + r')\\b\\s*')\n","    text = pattern.sub('', text)\n","\n","    # Remove text in parentheses\n","    text = re.sub(r'\\([^)]*\\)', '', text)\n","\n","    # Spacing and filters\n","    text = re.sub(r\"([-;;.,!?<=>])\", r\" \\1 \", text)\n","    text = re.sub('[^A-Za-z0-9]+', ' ', text) # remove non 
alphanumeric chars\n","    text = re.sub(' +', ' ', text)  # remove multiple spaces\n","    text = text.strip()\n","\n","    return text"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tsWX-VNQkARQ","colab":{"base_uri":"https://localhost:8080/","height":35},"executionInfo":{"status":"ok","timestamp":1608249764794,"user_tz":420,"elapsed":1618,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f11955cc-f7a7-4a6b-a4a0-fcda5de3cff9"},"source":["# Sample\n","text = \"Great week for the NYSE!\"\n","preprocess(text=text)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"application/vnd.google.colaboratory.intrinsic+json":{"type":"string"},"text/plain":["'great week nyse'"]},"metadata":{"tags":[]},"execution_count":28}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5Up0hTP8kwJx","executionInfo":{"status":"ok","timestamp":1608249766536,"user_tz":420,"elapsed":3162,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"03d221e5-06b3-4491-c8ef-6e56e769d33e"},"source":["# Apply to dataframe\n","preprocessed_df = df.copy()\n","preprocessed_df.title = preprocessed_df.title.apply(preprocess)\n","print (f\"{df.title.values[0]}\\n\\n{preprocessed_df.title.values[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["The Spirit of St. Louis in outer space\n","\n","spirit st louis outer space\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"_F_TNV8Nk7Za"},"source":["> If you have preprocessing steps like standardization, etc. that are calculated, you need to separate the training and test set first before applying those operations. 
This is because we must not accidentally apply any knowledge gained from the test set (data leakage) during preprocessing/training. However, for global preprocessing steps like the function above, where we aren't learning anything from the data itself, we can apply them before splitting the data."]},{"cell_type":"markdown","metadata":{"id":"cZaX1WFyk8gR"},"source":["## Split data"]},{"cell_type":"code","metadata":{"id":"1iy_Ej7ukwMt"},"source":["import collections\n","from sklearn.model_selection import train_test_split"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Bw5zFpt3k-wz"},"source":["TRAIN_SIZE = 0.7\n","VAL_SIZE = 0.15\n","TEST_SIZE = 0.15"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SHutQd7Pk-zt"},"source":["def train_val_test_split(X, y, train_size):\n","    \"\"\"Split dataset into train, validation and test splits.\"\"\"\n","    # Keep train_size of the data for training; split the remainder evenly into val and test\n","    X_train, X_, y_train, y_ = train_test_split(X, y, train_size=train_size, stratify=y)\n","    X_val, X_test, y_val, y_test = train_test_split(X_, y_, train_size=0.5, stratify=y_)\n","    return X_train, X_val, X_test, y_train, y_val, y_test"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"SOz9QSm5lBH7"},"source":["# Data\n","X = preprocessed_df[\"title\"].values\n","y = preprocessed_df[\"category\"].values"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nU8ubwislBKk","executionInfo":{"status":"ok","timestamp":1608249766539,"user_tz":420,"elapsed":1879,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f1127638-129f-4edb-a7c5-96f44c2fa041"},"source":["# Create data splits\n","X_train, X_val, X_test, y_train, y_val, y_test = train_val_test_split(\n","    X=X, y=y, train_size=TRAIN_SIZE)\n","print (f\"X_train: {X_train.shape}, y_train: {y_train.shape}\")\n","print 
(f\"X_val: {X_val.shape}, y_val: {y_val.shape}\")\n","print (f\"X_test: {X_test.shape}, y_test: {y_test.shape}\")\n","print (f\"Sample point: {X_train[0]} → {y_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["X_train: (84000,), y_train: (84000,)\n","X_val: (18000,), y_val: (18000,)\n","X_test: (18000,), y_test: (18000,)\n","Sample point: nba wrap neal pours 40 heat subdue wizards → Sports\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"JZhaxH8xmHAy"},"source":["## LabelEncoder"]},{"cell_type":"markdown","metadata":{"id":"fYueMyIUmHEh"},"source":["Next we'll define a `LabelEncoder` to encode our text labels into unique indices"]},{"cell_type":"code","metadata":{"id":"DsPgVemMmHJK"},"source":["import itertools"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"dlZ4w8OfmHM2"},"source":["class LabelEncoder(object):\n","    \"\"\"Label encoder for tag labels.\"\"\"\n","    def __init__(self, class_to_index={}):\n","        self.class_to_index = class_to_index\n","        self.index_to_class = {v: k for k, v in self.class_to_index.items()}\n","        self.classes = list(self.class_to_index.keys())\n","\n","    def __len__(self):\n","        return len(self.class_to_index)\n","\n","    def __str__(self):\n","        return f\"<LabelEncoder(num_classes={len(self)})>\"\n","\n","    def fit(self, y):\n","        classes = np.unique(y)\n","        for i, class_ in enumerate(classes):\n","            self.class_to_index[class_] = i\n","        self.index_to_class = {v: k for k, v in self.class_to_index.items()}\n","        self.classes = list(self.class_to_index.keys())\n","        return self\n","\n","    def encode(self, y):\n","        encoded = np.zeros((len(y)), dtype=int)\n","        for i, item in enumerate(y):\n","            encoded[i] = self.class_to_index[item]\n","        return encoded\n","\n","    def decode(self, y):\n","        classes = []\n","        for i, item in enumerate(y):\n"," 
           classes.append(self.index_to_class[item])\n","        return classes\n","\n","    def save(self, fp):\n","        with open(fp, 'w') as fp:\n","            contents = {'class_to_index': self.class_to_index}\n","            json.dump(contents, fp, indent=4, sort_keys=False)\n","\n","    @classmethod\n","    def load(cls, fp):\n","        with open(fp, 'r') as fp:\n","            kwargs = json.load(fp=fp)\n","        return cls(**kwargs)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"nA4TyhuBmHPu","executionInfo":{"status":"ok","timestamp":1608249768403,"user_tz":420,"elapsed":686,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"ada26a5c-c7e7-47bb-fd54-f603ae237160"},"source":["# Encode\n","label_encoder = LabelEncoder()\n","label_encoder.fit(y_train)\n","NUM_CLASSES = len(label_encoder)\n","label_encoder.class_to_index"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'Business': 0, 'Sci/Tech': 1, 'Sports': 2, 'World': 3}"]},"metadata":{"tags":[]},"execution_count":38}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"FjEqwySTmHSg","executionInfo":{"status":"ok","timestamp":1608249768404,"user_tz":420,"elapsed":585,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a282ef19-441d-4f2d-cb85-7dcfd1e6218b"},"source":["# Convert labels to tokens\n","print (f\"y_train[0]: {y_train[0]}\")\n","y_train = label_encoder.encode(y_train)\n","y_val = label_encoder.encode(y_val)\n","y_test = label_encoder.encode(y_test)\n","print (f\"y_train[0]: {y_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["y_train[0]: 
Sports\n","y_train[0]: 2\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"xE5-9S6VmHVO","executionInfo":{"status":"ok","timestamp":1608249768612,"user_tz":420,"elapsed":334,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a8cb3a16-4ef0-4fc6-b48d-c60ab4c0f13b"},"source":["# Class weights\n","counts = np.bincount(y_train)\n","class_weights = {i: 1.0/count for i, count in enumerate(counts)}\n","print (f\"counts: {counts}\\nweights: {class_weights}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["counts: [21000 21000 21000 21000]\n","weights: {0: 4.761904761904762e-05, 1: 4.761904761904762e-05, 2: 4.761904761904762e-05, 3: 4.761904761904762e-05}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"mHRC9EfzlmP-"},"source":["## Tokenizer"]},{"cell_type":"markdown","metadata":{"id":"bkQmzdUXlmUH"},"source":["We'll define a `Tokenizer` to convert our text input data into token indices."]},{"cell_type":"code","metadata":{"id":"AcqOl3Lbk-2Q"},"source":["import json\n","from collections import Counter\n","from more_itertools import take"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"XbyIehIDl7l-"},"source":["class Tokenizer(object):\n","    def __init__(self, char_level, num_tokens=None, \n","                 pad_token='<PAD>', oov_token='<UNK>',\n","                 token_to_index=None):\n","        self.char_level = char_level\n","        self.separator = '' if self.char_level else ' '\n","        if num_tokens: num_tokens -= 2 # pad + unk tokens\n","        self.num_tokens = num_tokens\n","        self.pad_token = pad_token\n","        self.oov_token = oov_token\n","        if not token_to_index:\n","            token_to_index = {pad_token: 0, oov_token: 1}\n","        self.token_to_index = token_to_index\n","        
self.index_to_token = {v: k for k, v in self.token_to_index.items()}\n","\n","    def __len__(self):\n","        return len(self.token_to_index)\n","\n","    def __str__(self):\n","        return f\"<Tokenizer(num_tokens={len(self)})>\"\n","\n","    def fit_on_texts(self, texts):\n","        if not self.char_level:\n","            texts = [text.split(\" \") for text in texts]\n","        all_tokens = [token for text in texts for token in text]\n","        counts = Counter(all_tokens).most_common(self.num_tokens)\n","        self.min_token_freq = counts[-1][1]\n","        for token, count in counts:\n","            index = len(self)\n","            self.token_to_index[token] = index\n","            self.index_to_token[index] = token\n","        return self\n","\n","    def texts_to_sequences(self, texts):\n","        sequences = []\n","        for text in texts:\n","            if not self.char_level:\n","                text = text.split(' ')\n","            sequence = []\n","            for token in text:\n","                sequence.append(self.token_to_index.get(\n","                    token, self.token_to_index[self.oov_token]))\n","            sequences.append(np.asarray(sequence))\n","        return sequences\n","\n","    def sequences_to_texts(self, sequences):\n","        texts = []\n","        for sequence in sequences:\n","            text = []\n","            for index in sequence:\n","                text.append(self.index_to_token.get(index, self.oov_token))\n","            texts.append(self.separator.join([token for token in text]))\n","        return texts\n","\n","    def save(self, fp):\n","        with open(fp, 'w') as fp:\n","            contents = {\n","                'char_level': self.char_level,\n","                'oov_token': self.oov_token,\n","                'token_to_index': self.token_to_index\n","            }\n","            json.dump(contents, fp, indent=4, sort_keys=False)\n","\n","    @classmethod\n","    def load(cls, fp):\n"," 
       with open(fp, 'r') as fp:\n","            kwargs = json.load(fp=fp)\n","        return cls(**kwargs)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"gZgQb7cdqD6q"},"source":["> It's important that we only fit using our train data split because during inference, our model will not always know every token so it's important to replicate that scenario with our validation and test splits as well."]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"XlMH93AKl7oc","executionInfo":{"status":"ok","timestamp":1608249770260,"user_tz":420,"elapsed":612,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"3f20c3cf-b4e1-47be-d5dc-8f56a63afc4a"},"source":["# Tokenize\n","tokenizer = Tokenizer(char_level=False, num_tokens=5000)\n","tokenizer.fit_on_texts(texts=X_train)\n","VOCAB_SIZE = len(tokenizer)\n","print (tokenizer)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<Tokenizer(num_tokens=5000)>\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VT93_ZFIl7rE","executionInfo":{"status":"ok","timestamp":1608249770260,"user_tz":420,"elapsed":456,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f02d0710-6b99-4a80-a50f-fb27bef475ca"},"source":["# Sample of tokens\n","print (take(5, tokenizer.token_to_index.items()))\n","print (f\"least freq token's freq: {tokenizer.min_token_freq}\") # use this to adjust num_tokens"],"execution_count":null,"outputs":[{"output_type":"stream","text":["[('<PAD>', 0), ('<UNK>', 1), ('39', 2), ('b', 3), ('gt', 4)]\n","least freq token's freq: 
14\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Mz2XAlijl7u0","executionInfo":{"status":"ok","timestamp":1608249771770,"user_tz":420,"elapsed":1090,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"23161be4-416f-472d-94de-55c512ffbc12"},"source":["# Convert texts to sequences of indices\n","X_train = tokenizer.texts_to_sequences(X_train)\n","X_val = tokenizer.texts_to_sequences(X_val)\n","X_test = tokenizer.texts_to_sequences(X_test)\n","preprocessed_text = tokenizer.sequences_to_texts([X_train[0]])[0]\n","print (\"Text to indices:\\n\"\n","    f\"  (preprocessed) → {preprocessed_text}\\n\"\n","    f\"  (tokenized) → {X_train[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Text to indices:\n","  (preprocessed) → nba wrap neal <UNK> 40 heat <UNK> wizards\n","  (tokenized) → [ 299  359 3869    1 1648  734    1 2021]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"581nl9EYFAsS"},"source":["# Embedding layer"]},{"cell_type":"markdown","metadata":{"id":"JbOzzfLNFCtW"},"source":["We can embed our inputs using PyTorch's [embedding layer](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding)."]},{"cell_type":"code","metadata":{"id":"1tHb3v_KH53e","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249787092,"user_tz":420,"elapsed":10907,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"cf03b218-df68-4620-96e7-f6f59f95ba12"},"source":["# Input\n","vocab_size = 10\n","x = torch.randint(high=vocab_size, size=(1,5))\n","print (x)\n","print (x.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["tensor([[2, 6, 
5, 2, 6]])\n","torch.Size([1, 5])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"FXUpmH7AFOJh","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249787093,"user_tz":420,"elapsed":10703,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"3c3cdf59-fd99-4e9b-831c-3e813672cfbb"},"source":["# Embedding layer\n","embeddings = nn.Embedding(embedding_dim=100, num_embeddings=vocab_size)\n","print (embeddings.weight.shape)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["torch.Size([10, 100])\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"bVGWIgEGGmHn","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608249787094,"user_tz":420,"elapsed":10219,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f510b101-fe78-490a-a148-7684ed1585f0"},"source":["# Embed the input\n","embeddings(x).shape"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["torch.Size([1, 5, 100])"]},"metadata":{"tags":[]},"execution_count":51}]},{"cell_type":"markdown","metadata":{"id":"WbO8HYjaGxZY"},"source":["Each token in the input is represented via embeddings (all out-of-vocabulary (OOV) tokens are given the embedding for `UNK` token.) In the model below, we'll see how to set these embeddings to be pretrained GloVe embeddings and how to choose whether to freeze (fixed embedding weights) those embeddings or not during training."]},{"cell_type":"markdown","metadata":{"id":"uTWMME1VmaTQ"},"source":["# Padding"]},{"cell_type":"markdown","metadata":{"id":"qu5gHg_Fmzdp"},"source":["Our inputs are all of varying length but we need each batch to be uniformly shaped. 
Therefore, we will use padding to make all the inputs in the batch the same length. Our padding index will be 0 (note that this is consistent with the `<PAD>` token defined in our `Tokenizer`).\n","\n","> While embedding our input tokens will create a batch of shape (`N`, `max_seq_len`, `embed_dim`) we only need to provide a 2D matrix (`N`, `max_seq_len`) for using embeddings with PyTorch."]},{"cell_type":"code","metadata":{"id":"JJE5dW33mHZn"},"source":["def pad_sequences(sequences, max_seq_len=0):\n","    \"\"\"Pad sequences to max length in sequence.\"\"\"\n","    max_seq_len = max(max_seq_len, max(len(sequence) for sequence in sequences))\n","    padded_sequences = np.zeros((len(sequences), max_seq_len))\n","    for i, sequence in enumerate(sequences):\n","        padded_sequences[i][:len(sequence)] = sequence\n","    return padded_sequences"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"5niX_T9ZmHcn","executionInfo":{"status":"ok","timestamp":1608249787095,"user_tz":420,"elapsed":6649,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"0fbea476-bc76-4462-b591-ad5946ac10df"},"source":["# 2D sequences\n","padded = pad_sequences(X_train[0:3])\n","print (padded.shape)\n","print (padded)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(3, 8)\n","[[2.990e+02 3.590e+02 3.869e+03 1.000e+00 1.648e+03 7.340e+02 1.000e+00\n","  2.021e+03]\n"," [4.977e+03 1.000e+00 8.070e+02 0.000e+00 0.000e+00 0.000e+00 0.000e+00\n","  0.000e+00]\n"," [5.900e+01 1.213e+03 1.160e+02 4.042e+03 2.040e+02 4.190e+02 1.000e+00\n","  0.000e+00]]\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"t8LGKXCmUgzV"},"source":["# Dataset"]},{"cell_type":"markdown","metadata":{"id":"ACBKJ77TVBpi"},"source":["We're going to create Datasets and DataLoaders to be 
able to efficiently create batches with our data splits."]},{"cell_type":"code","metadata":{"id":"jREEFz72Hssx"},"source":["FILTER_SIZES = list(range(1, 4)) # uni, bi and tri grams"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"2K0D-vTGUgHV"},"source":["class Dataset(torch.utils.data.Dataset):\n","    def __init__(self, X, y, max_filter_size):\n","        self.X = X\n","        self.y = y\n","        self.max_filter_size = max_filter_size\n","\n","    def __len__(self):\n","        return len(self.y)\n","\n","    def __str__(self):\n","        return f\"<Dataset(N={len(self)})>\"\n","\n","    def __getitem__(self, index):\n","        X = self.X[index]\n","        y = self.y[index]\n","        return [X, y]\n","\n","    def collate_fn(self, batch):\n","        \"\"\"Processing on a batch.\"\"\"\n","        # Get inputs\n","        batch = np.array(batch, dtype=object)\n","        X = batch[:, 0]\n","        y = np.stack(batch[:, 1], axis=0)\n","\n","        # Pad sequences\n","        X = pad_sequences(X)\n","\n","        # Cast\n","        X = torch.LongTensor(X.astype(np.int32))\n","        y = torch.LongTensor(y.astype(np.int32))\n","\n","        return X, y\n","\n","    def create_dataloader(self, batch_size, shuffle=False, drop_last=False):\n","        return torch.utils.data.DataLoader(\n","            dataset=self, batch_size=batch_size, collate_fn=self.collate_fn,\n","            shuffle=shuffle, drop_last=drop_last, pin_memory=True)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tdAyMvfnUgP9","executionInfo":{"status":"ok","timestamp":1608250567232,"user_tz":420,"elapsed":745,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"161a90cc-0400-4a0e-c83a-37182cda415f"},"source":["# Create datasets\n","max_filter_size = 
max(FILTER_SIZES)\n","train_dataset = Dataset(X=X_train, y=y_train, max_filter_size=max_filter_size)\n","val_dataset = Dataset(X=X_val, y=y_val, max_filter_size=max_filter_size)\n","test_dataset = Dataset(X=X_test, y=y_test, max_filter_size=max_filter_size)\n","print (\"Datasets:\\n\"\n","    f\"  Train dataset:{train_dataset.__str__()}\\n\"\n","    f\"  Val dataset: {val_dataset.__str__()}\\n\"\n","    f\"  Test dataset: {test_dataset.__str__()}\\n\"\n","    \"Sample point:\\n\"\n","    f\"  X: {train_dataset[0][0]}\\n\"\n","    f\"  y: {train_dataset[0][1]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Datasets:\n","  Train dataset:<Dataset(N=84000)>\n","  Val dataset: <Dataset(N=18000)>\n","  Test dataset: <Dataset(N=18000)>\n","Sample point:\n","  X: [ 299  359 3869    1 1648  734    1 2021]\n","  y: 2\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"VeK0C3ORUgTu","executionInfo":{"status":"ok","timestamp":1608250567570,"user_tz":420,"elapsed":857,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"ad19832b-95fd-4132-b83c-edd91980b248"},"source":["# Create dataloaders\n","batch_size = 64\n","train_dataloader = train_dataset.create_dataloader(batch_size=batch_size)\n","val_dataloader = val_dataset.create_dataloader(batch_size=batch_size)\n","test_dataloader = test_dataset.create_dataloader(batch_size=batch_size)\n","batch_X, batch_y = next(iter(train_dataloader))\n","print (\"Sample batch:\\n\"\n","    f\"  X: {list(batch_X.size())}\\n\"\n","    f\"  y: {list(batch_y.size())}\\n\"\n","    \"Sample point:\\n\"\n","    f\"  X: {batch_X[0]}\\n\"\n","    f\"  y: {batch_y[0]}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Sample batch:\n","  X: [64, 9]\n","  y: [64]\n","Sample point:\n","  X: tensor([ 299,  359, 3869,    1, 
1648,  734,    1, 2021,    0], device='cpu')\n","  y: 2\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"pfhjWZRD94hK"},"source":["# Model"]},{"cell_type":"markdown","metadata":{"id":"h0U0V8fViHZc"},"source":["We'll be using a convolutional neural network on top of our embedded tokens to extract meaningful spatial signal. This time, we'll be using several filter widths to act as n-gram feature extractors."]},{"cell_type":"markdown","metadata":{"id":"goDI3dSxHbgr"},"source":["Let's visualize the model's forward pass.\n","\n","1. We'll first tokenize our inputs (`batch_size`, `max_seq_len`).\n","2. Then we'll embed our tokenized inputs (`batch_size`, `max_seq_len`, `embedding_dim`).\n","3. We'll apply convolution via filters (`filter_size`, `embedding_dim`, `num_filters`), using `SAME` padding so that each output has the same length as the input sequence. Our filters act as token-level n-gram detectors. We have three different filter sizes (1, 2 and 3), and they act as unigram, bigram and trigram feature extractors, respectively.\n","4. We'll apply 1D global max pooling, which extracts the most relevant information from each feature map for making the decision.\n","5. We feed the pooled outputs to a fully-connected (FC) layer (with dropout).\n","6. We use one more FC layer to produce class logits, which softmax turns into class probabilities. 
"]},{"cell_type":"markdown","metadata":{"id":"EIheSuazHeBT"},"source":["<div align=\"left\">\n","<img src=\"https://raw.githubusercontent.com/GokuMohandas/MadeWithML/main/images/basics/embeddings/model.png\" width=\"1000\">\n","</div>"]},{"cell_type":"code","metadata":{"id":"_I3dmAFtsfy6"},"source":["import math\n","import torch.nn.functional as F"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"z1rRdLydmjdp"},"source":["EMBEDDING_DIM = 100\n","HIDDEN_DIM = 100\n","DROPOUT_P = 0.1"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"juRjat3CiShK"},"source":["class CNN(nn.Module):\n","    def __init__(self, embedding_dim, vocab_size, num_filters, \n","                 filter_sizes, hidden_dim, dropout_p, num_classes, \n","                 pretrained_embeddings=None, freeze_embeddings=False,\n","                 padding_idx=0):\n","        super(CNN, self).__init__()\n","\n","        # Filter sizes\n","        self.filter_sizes = filter_sizes\n","        \n","        # Initialize embeddings\n","        if pretrained_embeddings is None:\n","            self.embeddings = nn.Embedding(\n","                embedding_dim=embedding_dim, num_embeddings=vocab_size,\n","                padding_idx=padding_idx)\n","        else:\n","            pretrained_embeddings = torch.from_numpy(pretrained_embeddings).float()\n","            self.embeddings = nn.Embedding(\n","                embedding_dim=embedding_dim, num_embeddings=vocab_size,\n","                padding_idx=padding_idx, _weight=pretrained_embeddings)\n","        \n","        # Freeze embeddings or not\n","        if freeze_embeddings:\n","            self.embeddings.weight.requires_grad = False\n","        \n","        # Conv weights\n","        self.conv = nn.ModuleList(\n","            [nn.Conv1d(in_channels=embedding_dim, \n","                       out_channels=num_filters, \n","                       kernel_size=f) for f in filter_sizes])\n","     \n","        
# FC weights\n","        self.dropout = nn.Dropout(dropout_p)\n","        self.fc1 = nn.Linear(num_filters*len(filter_sizes), hidden_dim)\n","        self.fc2 = nn.Linear(hidden_dim, num_classes)\n","\n","    def forward(self, inputs, channel_first=False, apply_softmax=False):\n","        \n","        # Embed\n","        x_in, = inputs\n","        x_in = self.embeddings(x_in)\n","\n","        # Rearrange input so num_channels is in dim 1 (N, C, L)\n","        if not channel_first:\n","            x_in = x_in.transpose(1, 2)\n","            \n","        # Conv outputs\n","        z = []\n","        max_seq_len = x_in.shape[2]\n","        for i, f in enumerate(self.filter_sizes):\n","            # `SAME` padding\n","            padding_left = int((self.conv[i].stride[0]*(max_seq_len-1) - max_seq_len + self.filter_sizes[i])/2)\n","            padding_right = int(math.ceil((self.conv[i].stride[0]*(max_seq_len-1) - max_seq_len + self.filter_sizes[i])/2))\n","\n","            # Conv + pool\n","            _z = self.conv[i](F.pad(x_in, (padding_left, padding_right)))\n","            _z = F.max_pool1d(_z, _z.size(2)).squeeze(2)\n","            z.append(_z)\n","        \n","        # Concat conv outputs\n","        z = torch.cat(z, 1)\n","\n","        # FC layers\n","        z = self.fc1(z)\n","        z = self.dropout(z)\n","        y_pred = self.fc2(z)\n","        \n","        if apply_softmax:\n","            y_pred = F.softmax(y_pred, dim=1)\n","        return y_pred"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"QBmYu6wjkgf0"},"source":["# GloVe embeddings"]},{"cell_type":"markdown","metadata":{"id":"RFRaj2AUojN5"},"source":["We're going to create some utility functions to load the pretrained GloVe embeddings into our `Embedding` layer."]},{"cell_type":"code","metadata":{"id":"x9uev5AGsuqq"},"source":["def load_glove_embeddings(embeddings_file):\n","    \"\"\"Load embeddings from a file.\"\"\"\n","    embeddings = {}\n","    
with open(embeddings_file, \"r\") as fp:\n","        for index, line in enumerate(fp):\n","            values = line.split()\n","            word = values[0]\n","            embedding = np.asarray(values[1:], dtype='float32')\n","            embeddings[word] = embedding\n","    return embeddings"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"tQHD-ThwWnjD"},"source":["def make_embeddings_matrix(embeddings, word_index, embedding_dim):\n","    \"\"\"Create embeddings matrix to use in Embedding layer.\"\"\"\n","    embedding_matrix = np.zeros((len(word_index), embedding_dim))\n","    for word, i in word_index.items():\n","        embedding_vector = embeddings.get(word)\n","        if embedding_vector is not None:\n","            embedding_matrix[i] = embedding_vector\n","    return embedding_matrix"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"9WxP2GR3LmrO","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250578680,"user_tz":420,"elapsed":10090,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"9ed4eca0-bec3-46f3-a24e-4f767815f406"},"source":["# Create embeddings\n","embeddings_file = 'glove.6B.{0}d.txt'.format(EMBEDDING_DIM)\n","glove_embeddings = load_glove_embeddings(embeddings_file=embeddings_file)\n","embedding_matrix = make_embeddings_matrix(\n","    embeddings=glove_embeddings, word_index=tokenizer.token_to_index, \n","    embedding_dim=EMBEDDING_DIM)\n","print (f\"<Embeddings(words={embedding_matrix.shape[0]}, dim={embedding_matrix.shape[1]})>\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<Embeddings(words=5000, dim=100)>\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"C26maF-9Goit"},"source":["# Experiments"]},{"cell_type":"markdown","metadata":{"id":"eTWQcUJ_GrIx"},"source":["We 
first have to decide whether to use pretrained embeddings or randomly initialized ones. Then, we can choose to freeze our embeddings or continue to train them using the supervised data (which could lead to overfitting). Here are the three experiments we're going to conduct:\n","* randomly initialized embeddings (fine-tuned)\n","* GloVe embeddings (frozen)\n","* GloVe embeddings (fine-tuned)"]},{"cell_type":"code","metadata":{"id":"geKOPVzVK6S9"},"source":["import copy\n","import json\n","from sklearn.metrics import precision_recall_fscore_support\n","from torch.optim import Adam"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"64iPmq2lDv2h"},"source":["NUM_FILTERS = 50\n","LEARNING_RATE = 1e-3\n","PATIENCE = 3\n","NUM_EPOCHS = 10"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"iIXt8XA09vYX"},"source":["class Trainer(object):\n","    def __init__(self, model, device, loss_fn=None, optimizer=None, scheduler=None):\n","\n","        # Set params\n","        self.model = model\n","        self.device = device\n","        self.loss_fn = loss_fn\n","        self.optimizer = optimizer\n","        self.scheduler = scheduler\n","\n","    def train_step(self, dataloader):\n","        \"\"\"Train step.\"\"\"\n","        # Set model to train mode\n","        self.model.train()\n","        loss = 0.0\n","\n","        # Iterate over train batches\n","        for i, batch in enumerate(dataloader):\n","\n","            # Step\n","            batch = [item.to(self.device) for item in batch]  # Set device\n","            inputs, targets = batch[:-1], batch[-1]\n","            self.optimizer.zero_grad()  # Reset gradients\n","            z = self.model(inputs)  # Forward pass\n","            J = self.loss_fn(z, targets)  # Define loss\n","            J.backward()  # Backward pass\n","            self.optimizer.step()  # Update weights\n","\n","            # Cumulative Metrics\n","            loss += (J.detach().item() - loss) / (i + 
1)\n","\n","        return loss\n","\n","    def eval_step(self, dataloader):\n","        \"\"\"Validation or test step.\"\"\"\n","        # Set model to eval mode\n","        self.model.eval()\n","        loss = 0.0\n","        y_trues, y_probs = [], []\n","\n","        # Iterate over val batches\n","        with torch.no_grad():\n","            for i, batch in enumerate(dataloader):\n","\n","                # Step\n","                batch = [item.to(self.device) for item in batch]  # Set device\n","                inputs, y_true = batch[:-1], batch[-1]\n","                z = self.model(inputs)  # Forward pass\n","                J = self.loss_fn(z, y_true).item()\n","\n","                # Cumulative Metrics\n","                loss += (J - loss) / (i + 1)\n","\n","                # Store outputs (softmax over classes for multi-class probabilities)\n","                y_prob = F.softmax(z, dim=1).cpu().numpy()\n","                y_probs.extend(y_prob)\n","                y_trues.extend(y_true.cpu().numpy())\n","\n","        return loss, np.vstack(y_trues), np.vstack(y_probs)\n","\n","    def predict_step(self, dataloader):\n","        \"\"\"Prediction step.\"\"\"\n","        # Set model to eval mode\n","        self.model.eval()\n","        y_probs = []\n","\n","        # Iterate over batches\n","        with torch.no_grad():\n","            for i, batch in enumerate(dataloader):\n","\n","                # Forward pass w/ inputs\n","                batch = [item.to(self.device) for item in batch]  # Set device\n","                inputs, targets = batch[:-1], batch[-1]\n","                y_prob = self.model(inputs, apply_softmax=True).cpu().numpy()\n","\n","                # Store outputs\n","                y_probs.extend(y_prob)\n","\n","        return np.vstack(y_probs)\n","    \n","    def train(self, num_epochs, patience, train_dataloader, val_dataloader):\n","        best_val_loss = np.inf\n","        for epoch in range(num_epochs):\n","            # Steps\n","            train_loss = 
self.train_step(dataloader=train_dataloader)\n","            val_loss, _, _ = 
self.eval_step(dataloader=val_dataloader)\n","            self.scheduler.step(val_loss)\n","\n","            # Early stopping\n","            if val_loss < best_val_loss:\n","                best_val_loss = val_loss\n","                best_model = copy.deepcopy(self.model)  # snapshot, not a reference\n","                _patience = patience  # reset _patience\n","            else:\n","                _patience -= 1\n","            if not _patience:  # 0\n","                print(\"Stopping early!\")\n","                break\n","\n","            # Logging\n","            print(\n","                f\"Epoch: {epoch+1} | \"\n","                f\"train_loss: {train_loss:.5f}, \"\n","                f\"val_loss: {val_loss:.5f}, \"\n","                f\"lr: {self.optimizer.param_groups[0]['lr']:.2E}, \"\n","                f\"_patience: {_patience}\"\n","            )\n","        return best_model"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"Us7Smprz9cWO"},"source":["def get_performance(y_true, y_pred, classes):\n","    \"\"\"Per-class performance metrics.\"\"\"\n","    # Performance\n","    performance = {\"overall\": {}, \"class\": {}}\n","\n","    # Overall performance\n","    metrics = precision_recall_fscore_support(y_true, y_pred, average=\"weighted\")\n","    performance[\"overall\"][\"precision\"] = metrics[0]\n","    performance[\"overall\"][\"recall\"] = metrics[1]\n","    performance[\"overall\"][\"f1\"] = metrics[2]\n","    performance[\"overall\"][\"num_samples\"] = np.float64(len(y_true))\n","\n","    # Per-class performance\n","    metrics = precision_recall_fscore_support(y_true, y_pred, average=None)\n","    for i in range(len(classes)):\n","        performance[\"class\"][classes[i]] = {\n","            \"precision\": metrics[0][i],\n","            \"recall\": metrics[1][i],\n","            \"f1\": metrics[2][i],\n","            \"num_samples\": np.float64(metrics[3][i]),\n","        }\n","\n","    return 
performance"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"Y8JzMrcv_p8a"},"source":["### Randomly initialized embeddings"]},{"cell_type":"code","metadata":{"id":"TnLSYV0WKo8x"},"source":["PRETRAINED_EMBEDDINGS = None\n","FREEZE_EMBEDDINGS = False"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wD4sRUS5_lwq","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250579005,"user_tz":420,"elapsed":8166,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"682e80e9-7d7c-42e0-8880-d64532d6224f"},"source":["# Initialize model\n","model = CNN(\n","    embedding_dim=EMBEDDING_DIM, vocab_size=VOCAB_SIZE, \n","    num_filters=NUM_FILTERS, filter_sizes=FILTER_SIZES,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES,\n","    pretrained_embeddings=PRETRAINED_EMBEDDINGS, freeze_embeddings=FREEZE_EMBEDDINGS)\n","model = model.to(device) # set device\n","print (model.named_parameters)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<bound method Module.named_parameters of CNN(\n","  (embeddings): Embedding(5000, 100, padding_idx=0)\n","  (conv): ModuleList(\n","    (0): Conv1d(100, 50, kernel_size=(1,), stride=(1,))\n","    (1): Conv1d(100, 50, kernel_size=(2,), stride=(1,))\n","    (2): Conv1d(100, 50, kernel_size=(3,), stride=(1,))\n","  )\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc1): Linear(in_features=150, out_features=100, bias=True)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")>\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"0uiqDFypJLU9"},"source":["# Define Loss\n","class_weights_tensor = torch.Tensor(list(class_weights.values())).to(device)\n","loss_fn = 
nn.CrossEntropyLoss(weight=class_weights_tensor)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"BVLmJFYFJLXs"},"source":["# Define optimizer & scheduler\n","optimizer = Adam(model.parameters(), lr=LEARNING_RATE) \n","scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n","    optimizer, mode='min', factor=0.1, patience=3)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"CaV-uU0tJLbI"},"source":["# Trainer module\n","trainer = Trainer(\n","    model=model, device=device, loss_fn=loss_fn, \n","    optimizer=optimizer, scheduler=scheduler)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"2Ee8qqUnJLeg","executionInfo":{"status":"ok","timestamp":1608250621765,"user_tz":420,"elapsed":33687,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"b2502d32-3f6a-4537-80e8-6cdeb6c26915"},"source":["# Train\n","best_model = trainer.train(\n","    NUM_EPOCHS, PATIENCE, train_dataloader, val_dataloader)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch: 1 | train_loss: 0.77038, val_loss: 0.59683, lr: 1.00E-03, _patience: 3\n","Epoch: 2 | train_loss: 0.49571, val_loss: 0.54363, lr: 1.00E-03, _patience: 3\n","Epoch: 3 | train_loss: 0.40796, val_loss: 0.54551, lr: 1.00E-03, _patience: 2\n","Epoch: 4 | train_loss: 0.34797, val_loss: 0.57950, lr: 1.00E-03, _patience: 1\n","Stopping early!\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"PRzw6CpqJRUA"},"source":["# Get predictions\n","test_loss, y_true, y_prob = trainer.eval_step(dataloader=test_dataloader)\n","y_pred = np.argmax(y_prob, 
axis=1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qYUvB6FlJRW1","executionInfo":{"status":"ok","timestamp":1608250622444,"user_tz":420,"elapsed":17904,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a9e8ff54-beb0-4d29-b24e-1c9a9b01abd7"},"source":["# Determine performance\n","performance = get_performance(\n","    y_true=y_test, y_pred=y_pred, classes=label_encoder.classes)\n","print (json.dumps(performance['overall'], indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"precision\": 0.8070310520771562,\n","  \"recall\": 0.7999444444444445,\n","  \"f1\": 0.8012357147662316,\n","  \"num_samples\": 18000.0\n","}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"To_CB7ibLesP"},"source":["### GloVe embeddings (frozen)"]},{"cell_type":"code","metadata":{"id":"oT9w__AMkqfG"},"source":["PRETRAINED_EMBEDDINGS = embedding_matrix\n","FREEZE_EMBEDDINGS = True"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"yg13AyoUkqcJ","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250622446,"user_tz":420,"elapsed":14826,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"40e0fa9e-3cb5-413c-fd95-b7ed2ad6b1fc"},"source":["# Initialize model\n","model = CNN(\n","    embedding_dim=EMBEDDING_DIM, vocab_size=VOCAB_SIZE, \n","    num_filters=NUM_FILTERS, filter_sizes=FILTER_SIZES,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES,\n","    pretrained_embeddings=PRETRAINED_EMBEDDINGS, freeze_embeddings=FREEZE_EMBEDDINGS)\n","model = model.to(device) # set device\n","print 
(model.named_parameters)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<bound method Module.named_parameters of CNN(\n","  (embeddings): Embedding(5000, 100, padding_idx=0)\n","  (conv): ModuleList(\n","    (0): Conv1d(100, 50, kernel_size=(1,), stride=(1,))\n","    (1): Conv1d(100, 50, kernel_size=(2,), stride=(1,))\n","    (2): Conv1d(100, 50, kernel_size=(3,), stride=(1,))\n","  )\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc1): Linear(in_features=150, out_features=100, bias=True)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")>\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"6rJNp4Vb-dqz"},"source":["# Define Loss\n","class_weights_tensor = torch.Tensor(list(class_weights.values())).to(device)\n","loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"RKtdPOdM-dt0"},"source":["# Define optimizer & scheduler\n","optimizer = Adam(model.parameters(), lr=LEARNING_RATE) \n","scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n","    optimizer, mode='min', factor=0.1, patience=3)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"MtVthG4r-dwy"},"source":["# Trainer module\n","trainer = Trainer(\n","    model=model, device=device, loss_fn=loss_fn, \n","    optimizer=optimizer, scheduler=scheduler)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"oy3FP3ht-gJY","executionInfo":{"status":"ok","timestamp":1608250652270,"user_tz":420,"elapsed":31857,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"998ce771-66c3-4b39-8362-e241129e00c3"},"source":["# Train\n","best_model = trainer.train(\n","    NUM_EPOCHS, PATIENCE, train_dataloader, 
val_dataloader)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch: 1 | train_loss: 0.51510, val_loss: 0.47643, lr: 1.00E-03, _patience: 3\n","Epoch: 2 | train_loss: 0.44220, val_loss: 0.46124, lr: 1.00E-03, _patience: 3\n","Epoch: 3 | train_loss: 0.41204, val_loss: 0.46231, lr: 1.00E-03, _patience: 2\n","Epoch: 4 | train_loss: 0.38733, val_loss: 0.46606, lr: 1.00E-03, _patience: 1\n","Stopping early!\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"h2e0q965-gMK"},"source":["# Get predictions\n","test_loss, y_true, y_prob = trainer.eval_step(dataloader=test_dataloader)\n","y_pred = np.argmax(y_prob, axis=1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"qyNhVcHi-juw","executionInfo":{"status":"ok","timestamp":1608250652758,"user_tz":420,"elapsed":29176,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"b3cc6369-3d29-4810-df2b-4fe359f51417"},"source":["# Determine performance\n","performance = get_performance(\n","    y_true=y_test, y_pred=y_pred, classes=label_encoder.classes)\n","print (json.dumps(performance['overall'], indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"precision\": 0.8304874226557859,\n","  \"recall\": 0.8281111111111111,\n","  \"f1\": 0.828556487688813,\n","  \"num_samples\": 18000.0\n","}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"dUVkeDbNqO7V"},"source":["### Fine-tuned GloVe embeddings (unfrozen)"]},{"cell_type":"code","metadata":{"id":"eubLrHydkt_J"},"source":["PRETRAINED_EMBEDDINGS = embedding_matrix\n","FREEZE_EMBEDDINGS = 
False"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"IGeZwoy9qUpa","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1608250652759,"user_tz":420,"elapsed":27850,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"56db02ee-5401-4875-a0ef-6b938e8bc9f6"},"source":["# Initialize model\n","model = CNN(\n","    embedding_dim=EMBEDDING_DIM, vocab_size=VOCAB_SIZE, \n","    num_filters=NUM_FILTERS, filter_sizes=FILTER_SIZES,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES,\n","    pretrained_embeddings=PRETRAINED_EMBEDDINGS, freeze_embeddings=FREEZE_EMBEDDINGS)\n","model = model.to(device) # set device\n","print (model.named_parameters)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["<bound method Module.named_parameters of CNN(\n","  (embeddings): Embedding(5000, 100, padding_idx=0)\n","  (conv): ModuleList(\n","    (0): Conv1d(100, 50, kernel_size=(1,), stride=(1,))\n","    (1): Conv1d(100, 50, kernel_size=(2,), stride=(1,))\n","    (2): Conv1d(100, 50, kernel_size=(3,), stride=(1,))\n","  )\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc1): Linear(in_features=150, out_features=100, bias=True)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")>\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"ifqXyPZ1JKWY"},"source":["# Define Loss\n","class_weights_tensor = torch.Tensor(list(class_weights.values())).to(device)\n","loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"kXGrQ0ceJKZk"},"source":["# Define optimizer & scheduler\n","optimizer = Adam(model.parameters(), lr=LEARNING_RATE) \n","scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n","    optimizer, mode='min', factor=0.1, 
patience=3)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"IinLK_ohJKdr"},"source":["# Trainer module\n","trainer = Trainer(\n","    model=model, device=device, loss_fn=loss_fn, \n","    optimizer=optimizer, scheduler=scheduler)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"tpVOifjMJKgx","executionInfo":{"status":"ok","timestamp":1608250686423,"user_tz":420,"elapsed":60458,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"f281d7dc-2196-4c50-bf1b-8877c20db9ad"},"source":["# Train\n","best_model = trainer.train(\n","    NUM_EPOCHS, PATIENCE, train_dataloader, val_dataloader)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Epoch: 1 | train_loss: 0.48908, val_loss: 0.44320, lr: 1.00E-03, _patience: 3\n","Epoch: 2 | train_loss: 0.38986, val_loss: 0.43616, lr: 1.00E-03, _patience: 3\n","Epoch: 3 | train_loss: 0.34403, val_loss: 0.45240, lr: 1.00E-03, _patience: 2\n","Epoch: 4 | train_loss: 0.30224, val_loss: 0.49063, lr: 1.00E-03, _patience: 1\n","Stopping early!\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"TJfbNqp2JQgT"},"source":["# Get predictions\n","test_loss, y_true, y_prob = trainer.eval_step(dataloader=test_dataloader)\n","y_pred = np.argmax(y_prob, axis=1)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"thdmnUOTJQld","executionInfo":{"status":"ok","timestamp":1608250686999,"user_tz":420,"elapsed":59824,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"19b75fb4-a64f-45c0-c142-2e33338580bf"},"source":["# Determine performance\n","performance = get_performance(\n","    y_true=y_test, 
y_pred=y_pred, classes=label_encoder.classes)\n","print (json.dumps(performance['overall'], indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"precision\": 0.8297157849772082,\n","  \"recall\": 0.8263333333333334,\n","  \"f1\": 0.8266579939871359,\n","  \"num_samples\": 18000.0\n","}\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"id":"5R-df_DMY51A"},"source":["# Save artifacts\n","from pathlib import Path\n","dir = Path(\"cnn\")\n","dir.mkdir(parents=True, exist_ok=True)\n","label_encoder.save(fp=Path(dir, 'label_encoder.json'))\n","tokenizer.save(fp=Path(dir, 'tokenizer.json'))\n","torch.save(best_model.state_dict(), Path(dir, 'model.pt'))\n","with open(Path(dir, 'performance.json'), \"w\") as fp:\n","    json.dump(performance, indent=2, sort_keys=False, fp=fp)"],"execution_count":null,"outputs":[]},{"cell_type":"markdown","metadata":{"id":"xKdgLetZKhEj"},"source":["# Inference"]},{"cell_type":"code","metadata":{"id":"zPWRyqBoKks0"},"source":["def get_probability_distribution(y_prob, classes):\n","    \"\"\"Create a dict of class probabilities from an array.\"\"\"\n","    results = {}\n","    for i, class_ in enumerate(classes):\n","        results[class_] = np.float64(y_prob[i])\n","    sorted_results = {k: v for k, v in sorted(\n","        results.items(), key=lambda item: item[1], reverse=True)}\n","    return sorted_results"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"1zg0ErQMY4cZ","executionInfo":{"status":"ok","timestamp":1608254328349,"user_tz":420,"elapsed":773,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"a5925f8f-b8ca-4b38-c4f7-9ab446890378"},"source":["# Load artifacts\n","device = torch.device(\"cpu\")\n","label_encoder = LabelEncoder.load(fp=Path(dir, 'label_encoder.json'))\n","tokenizer = 
Tokenizer.load(fp=Path(dir, 'tokenizer.json'))\n","model = CNN(\n","    embedding_dim=EMBEDDING_DIM, vocab_size=VOCAB_SIZE, \n","    num_filters=NUM_FILTERS, filter_sizes=FILTER_SIZES,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES,\n","    pretrained_embeddings=PRETRAINED_EMBEDDINGS, freeze_embeddings=FREEZE_EMBEDDINGS)\n","model.load_state_dict(torch.load(Path(dir, 'model.pt'), map_location=device))\n","model.to(device)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["CNN(\n","  (embeddings): Embedding(5000, 100, padding_idx=0)\n","  (conv): ModuleList(\n","    (0): Conv1d(100, 50, kernel_size=(1,), stride=(1,))\n","    (1): Conv1d(100, 50, kernel_size=(2,), stride=(1,))\n","    (2): Conv1d(100, 50, kernel_size=(3,), stride=(1,))\n","  )\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc1): Linear(in_features=150, out_features=100, bias=True)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")"]},"metadata":{"tags":[]},"execution_count":203}]},{"cell_type":"code","metadata":{"id":"Bviv-K-FY4gS"},"source":["# Initialize trainer\n","trainer = Trainer(model=model, device=device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cDD44HKfY4jY","executionInfo":{"status":"ok","timestamp":1608255720107,"user_tz":420,"elapsed":386,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"546f31f7-8b4b-4004-b1d4-90fe4f1b03b8"},"source":["# Dataloader\n","text = \"The final tennis tournament starts next week.\"\n","X = tokenizer.texts_to_sequences([preprocess(text)])\n","print (tokenizer.sequences_to_texts(X))\n","y_filler = label_encoder.encode([label_encoder.classes[0]]*len(X))\n","dataset = Dataset(X=X, y=y_filler, max_filter_size=max_filter_size)\n","dataloader = 
dataset.create_dataloader(batch_size=batch_size)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["['final tennis tournament starts next week']\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"EXrACSa6ZJgb","executionInfo":{"status":"ok","timestamp":1608255720870,"user_tz":420,"elapsed":583,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"858e4b78-f00c-4ee4-e3b0-15f69e54d03c"},"source":["# Inference\n","y_prob = trainer.predict_step(dataloader)\n","y_pred = np.argmax(y_prob, axis=1)\n","label_encoder.decode(y_pred)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["['Sports']"]},"metadata":{"tags":[]},"execution_count":255}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Tbz0ZnYSZJkD","executionInfo":{"status":"ok","timestamp":1608255721724,"user_tz":420,"elapsed":939,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"ea77c0bc-430e-420a-e6b7-b26755a2a4eb"},"source":["# Class distributions\n","prob_dist = get_probability_distribution(y_prob=y_prob[0], classes=label_encoder.classes)\n","print (json.dumps(prob_dist, indent=2))"],"execution_count":null,"outputs":[{"output_type":"stream","text":["{\n","  \"Sports\": 0.9999998807907104,\n","  \"World\": 6.336378532978415e-08,\n","  \"Sci/Tech\": 2.107449992294619e-09,\n","  \"Business\": 3.706519813295728e-10\n","}\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"zXZtx3nsKlDr"},"source":["# Interpretability"]},{"cell_type":"markdown","metadata":{"id":"4KImhaLkcFuJ"},"source":["We went through all the trouble of padding our inputs before convolution because it 
makes each convolution output the same length 
as our inputs so we can try to get some interpretability. Since every token is mapped to a convolutional output on which we apply max pooling, we can see which token's output was most influential towards the prediction. We first need to get the conv outputs from our model:"]},{"cell_type":"code","metadata":{"id":"RnHVnI8hdxYC"},"source":["import collections\n","import seaborn as sns"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"atX4qOs1Kl1Y"},"source":["class InterpretableCNN(nn.Module):\n","    def __init__(self, embedding_dim, vocab_size, num_filters, \n","                 filter_sizes, hidden_dim, dropout_p, num_classes, \n","                 pretrained_embeddings=None, freeze_embeddings=False,\n","                 padding_idx=0):\n","        super(InterpretableCNN, self).__init__()\n","\n","        # Filter sizes\n","        self.filter_sizes = filter_sizes\n","        \n","        # Initialize embeddings\n","        if pretrained_embeddings is None:\n","            self.embeddings = nn.Embedding(\n","                embedding_dim=embedding_dim, num_embeddings=vocab_size,\n","                padding_idx=padding_idx)\n","        else:\n","            pretrained_embeddings = torch.from_numpy(pretrained_embeddings).float()\n","            self.embeddings = nn.Embedding(\n","                embedding_dim=embedding_dim, num_embeddings=vocab_size,\n","                padding_idx=padding_idx, _weight=pretrained_embeddings)\n","        \n","        # Freeze embeddings or not\n","        if freeze_embeddings:\n","            self.embeddings.weight.requires_grad = False\n","        \n","        # Conv weights\n","        self.conv = nn.ModuleList(\n","            [nn.Conv1d(in_channels=embedding_dim, \n","                       out_channels=num_filters, \n","                       kernel_size=f) for f in filter_sizes])\n","     \n","        # FC weights\n","        self.dropout = nn.Dropout(dropout_p)\n","        self.fc1 = 
nn.Linear(num_filters*len(filter_sizes), hidden_dim)\n","        self.fc2 = nn.Linear(hidden_dim, num_classes)\n","\n","    def forward(self, inputs, channel_first=False, apply_softmax=False):\n","        \n","        # Embed\n","        x_in, = inputs\n","        x_in = self.embeddings(x_in)\n","\n","        # Rearrange input so num_channels is in dim 1 (N, C, L)\n","        if not channel_first:\n","            x_in = x_in.transpose(1, 2)\n","            \n","        # Conv outputs\n","        z = []\n","        max_seq_len = x_in.shape[2]\n","        for i, f in enumerate(self.filter_sizes):\n","            # `SAME` padding\n","            padding_left = int((self.conv[i].stride[0]*(max_seq_len-1) - max_seq_len + self.filter_sizes[i])/2)\n","            padding_right = int(math.ceil((self.conv[i].stride[0]*(max_seq_len-1) - max_seq_len + self.filter_sizes[i])/2))\n","\n","            # Conv only (no pooling) so every token position keeps its output\n","            _z = self.conv[i](F.pad(x_in, (padding_left, padding_right)))\n","            z.append(_z.detach().cpu().numpy())\n","        \n","        # Return the raw conv outputs (one array per filter size) instead of class logits\n","        return z"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"id":"wybIldYFctMF"},"source":["PRETRAINED_EMBEDDINGS = embedding_matrix\n","FREEZE_EMBEDDINGS = False"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"QLNK4ez2ctPB","executionInfo":{"status":"ok","timestamp":1608255723908,"user_tz":420,"elapsed":679,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"c415d468-67d8-456e-91a6-4123e88e2f7a"},"source":["# Initialize model\n","interpretable_model = InterpretableCNN(\n","    embedding_dim=EMBEDDING_DIM, vocab_size=VOCAB_SIZE, \n","    num_filters=NUM_FILTERS, filter_sizes=FILTER_SIZES,\n","    hidden_dim=HIDDEN_DIM, dropout_p=DROPOUT_P, num_classes=NUM_CLASSES,\n","    
pretrained_embeddings=PRETRAINED_EMBEDDINGS, freeze_embeddings=FREEZE_EMBEDDINGS)\n","interpretable_model.load_state_dict(torch.load(Path(dir, 'model.pt'), map_location=device))\n","interpretable_model.to(device)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["InterpretableCNN(\n","  (embeddings): Embedding(5000, 100, padding_idx=0)\n","  (conv): ModuleList(\n","    (0): Conv1d(100, 50, kernel_size=(1,), stride=(1,))\n","    (1): Conv1d(100, 50, kernel_size=(2,), stride=(1,))\n","    (2): Conv1d(100, 50, kernel_size=(3,), stride=(1,))\n","  )\n","  (dropout): Dropout(p=0.1, inplace=False)\n","  (fc1): Linear(in_features=150, out_features=100, bias=True)\n","  (fc2): Linear(in_features=100, out_features=4, bias=True)\n",")"]},"metadata":{"tags":[]},"execution_count":260}]},{"cell_type":"code","metadata":{"id":"SGdju-O6dEwW"},"source":["# Initialize trainer\n","interpretable_trainer = Trainer(model=interpretable_model, device=device)"],"execution_count":null,"outputs":[]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0NJYmV6idE1v","executionInfo":{"status":"ok","timestamp":1608255724804,"user_tz":420,"elapsed":407,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"5c9f507e-dd04-4ebc-fd8b-2157ae558507"},"source":["# Get conv outputs\n","conv_outputs = interpretable_trainer.predict_step(dataloader)\n","print (conv_outputs.shape) # (len(filter_sizes), num_filters, max_seq_len)"],"execution_count":null,"outputs":[{"output_type":"stream","text":["(3, 50, 6)\n"],"name":"stdout"}]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/","height":333},"id":"w6jWfdK7dE43","executionInfo":{"status":"ok","timestamp":1608255725564,"user_tz":420,"elapsed":880,"user":{"displayName":"Goku 
Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"8aaed942-a4db-416d-86dc-6526cc6e3d55"},"source":["# Visualize a bi-gram filter's outputs\n","tokens = tokenizer.sequences_to_texts(X)[0].split(' ')\n","sns.heatmap(conv_outputs[1], xticklabels=tokens)"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["<matplotlib.axes._subplots.AxesSubplot at 0x7f30b2286710>"]},"metadata":{"tags":[]},"execution_count":263},{"output_type":"display_data","data":{"image/png":"iVBORw0KGgoAAAANSUhEUgAAAVwAAAErCAYAAACSMTtVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZhcVZ3/8fcnC0kgGwqEJciO7LIElGEGFBCDOqgIissI6JhRH3EdFQcf9x0RGf2pRAVGZQZHEVE2B2RHkQRkBxVwAQxLEEI2knT39/fHvQ2Vpqtv3e6quvdUPi+f+9h9q2/VJ6Fz6tS553yPIgIzM+u8cVUHMDNbV7jBNTPrEje4ZmZd4gbXzKxL3OCamXWJG1wzsy5xg2tm1iUTin5A0k7Aq4At8lMPAj+PiLs6GczMrNdopIUPkj4CvAE4B3ggPz0bOAY4JyK+WPQCp89+c1IrK970phVVRyht/A5bVx2htM999qGqI5SS1C9xbvs16X2APf7BH2qsz7Fm8X0t/+eauNG2Y369Mop6uG8Ddo2INY0nJX0VuAMobHDNzLpqoL/qBE0VvQUOAJsPc36z/LFhSZonaaGkhdcs/+NY8pmZlRMDrR9dVtTDfR/wK0l/BO7Pzz0P2B54d7OLImI+MB/gpK3fGH9t3jbXzoXfn1J1hNImR1ofzwGempzWh/St+sdXHaG03VhedYRqDNS3vRmxwY2ISyTtCOzH2jfNFkREffvtZrbOiv6+tjyPpC2B7wOzyIbx50fEaWN5zsJZChExAFw/lhcxM+ua9g0V9AEfjIibJE0DbpR0aUTcOdonLGxwzcyS0qabZhGxCFiUf71U0l1kn/RH3eCmN2/EzGwkJW6aNd7gz495wz2lpK2BvYDfjiVax3u4e67q6jS3MXueVlYdobTHB9arOkJpr1yZ1i2A+ZOfqjpCaWKDqiOUtm87nqTETbPGG/zNSJoKnAu8LyKeHEs0DymYWU9p100zAEkTyRrbsyPip2N9Pje4ZtZb2nTTTJKA7wF3RcRX2/GcHsM1s94y0N/6MbIDgH8BDpZ0c368fCzRWilesx8QEbFA0i7AXODuiLiolReYFavHkq/rlg+k1+mfs/eiqiOUtmZpWu/1H38kvQUxfX1rin+oF7WphxsR1wJtvQk1Yusi6RPA4cAESZcCLwSuAE6UtFdEfK6dYczMxizVlWbAUcCewCTgIWB2RDwp6Stk0yPc4JpZvVRQI6FVRZ/r+iKiPyJWAPcOTomIiJW0WLzm5yvua2NcM7ORRf+alo9uK2pwV0taP/96n8GTkmYwQoMbEfMjYk5EzDli/W3bENPMrEUJVws7MCJWwdM1FQZNBI5t5QV22/+RUUarxqTdZ
lUdoTRtvkfVEUpbfe0dVUcoZaA/vYUPSx+bXHWEaqQ6hjvY2A5zfjGwuCOJzMzGosZjuOnNgTIzG0mNd3xwg2tmvaWNS3vbreMN7tELJnX6Jdpqq9/V9+NIM4v676k6QmlHxnA7N9XXc/vS+71I0U7teBIPKZiZdUmqN83MzJLjBtfMrDvqvN1ixxvcef0bd/ol2uq2CWntJgvwutXpvW9ust7SqiOU8iDrF/9QzWxQ44ano1Lt4Up6IVktyCclTQFOBPYm29Pn8xGxpAsZzcxaV+NZCkVLe88AVuRfnwbMAL6Unzuzg7nMzEYn4aW94yJi8O1iTkTsnX99raSbm12Ub8Q2D+Bfp+/HoetvP/akZmatqPGQQlEP93ZJx+df3yJpDoCkHYGmpXYai9e4sTWzrkq4h/uvwGmSPkZWO+E3ku4H7s8f6znbr0lrJwKAe9ZLL/P41WkVVrliSno3U3fsT28354Pb8SQ17uEWFa9ZAhwnaTqwTf7zD0TEw90IZ2ZWWqoN7qC88PgtHc5iZjZ2NZ6lkN4ETjOzkazLtRQOnfNAp1+irSbvtVnVEUq78vS2bizaFbOmrCj+oRo5YUZ6BcgXPTK96gjVSH1IwcwsGetyD9fMrKvcwzUz65L++taQ6HiDO2mHGZ1+iba67burq45Q2l/WS6+wykN9af1eXLskvTmtH5ic3u9yW6Tcw5W0LXAksCXQD/wB+O98qpiZWb3UuMEdcYmSpPcA3wYmA/sCk8ga3uslvbjj6czMyqrx0t6iNaFvBw6PiM8ChwK7RsRJwFzg1GYXSZonaaGkhWfc+pf2pTUzKzIw0PrRZa0swh8cdpgETAWIiL8CE5td0Fi85q17bDX2lGZmrYpo/Sgg6QxJj0i6vR3RisZwvwsskPRb4J/IauEiaWPg7628wOPXLB9TwG7b46Sdq45Q2q533Ft1hNKuOTetm2b/wJSqI5R2Y6SXefd2PElfW5f2ngV8A/h+O56sqHjNaZIuA3YGTomIu/PzjwIHtiOAmVlbtXFsNiKulrR1u56vcJZCRNwB3NGuFzQz66QYaL2UZuNmCbn5ETG/7aFyXvhgZr2lxM2wvHHtWAM7VMcb3JMf3ajTL9FWS09eVHWE0naO51YdobRJk6pOUM6q9OoDMTG9munt4VoKZmZdUmJIodvS25vFzGwkfX2tHwUk/Q/wG+D5kh6Q9LaxRHMP18x6Swvza1t/qnhD256MLjS4Iq3Brw2br+eorT8rvSIlT1LfbVCGs3Dl/VVHKO2r2q7qCNWocS0F93DNrLfUeAzXDa6Z9ZYaz1IoqhY2XdIXJP1A0huHPPbNEa57unjN7UvTW3ZqZumKvv6Wj24rmqVwJiDgXOAYSedKGpxB+aJmFzUWr9lt2jo6jmRm1RiI1o8uKxpS2C4iXpt//TNJJwGXSzqi1Re4Yc3Dow5Xhb0mblx1hNImJHZjEmDeqvFVRyjlvRNnVx2htCf66juW2VE1HlIoanAnSRoXkf0JIuJzkh4EriYv1WhmVis1vmlWNKTwC+DgxhMRcRbwQSC9uUhm1vtqXIC8qDzjh5ucv0TS5zsTycxsDGrcwx3LtLBPkd1UG9H5O9V3PGU4N92e3mrnFUov85/Hp7XAZOfEFmoAzJywjn4ITXWbdEm3NnsImNX+OGZmYxMJrzSbBbwMeHzIeQG/7kgiM7OxSHhI4QJgakTcPPQBSVd2JJGZ2Vik2uBGRNNSZBHxxmaPNZr4nLTmiO64SUt7Y9bKokenVx2htHsmTK46QikLtV7VEUqbsyatcXKAfdvxJAnPwzUzS0uqPVwzs9REX317uEXFa+Y2fD1D0vck3SrpvyU1naXQWLzmrPv+1s68ZmYjq/HCh6IJnI2LG04BFgH/DCwATm92UWPxmuO23XzsKc3MWpVw8ZpGcyJiz/zrUyUd28pFf78rre1ZN//w3lVHKG2z56RXcGf2xy+qOkIpK5eld9NsRX96N83aIuEx3E0kfYBs3
u10SYp4esOg9JY3mVnPizbuadZuRQ3ud4Bp+df/BWwEPCppU+BZc3PNzCpX45tmRfNwP9Xk/EOSruhMJDOz0YuEhxRG0lLxmjWr0yo0PeHlb686QmlL/uX4qiOUds+Dm1YdoZTHx6U3g3K3aUNX5K8jUm1wXbzGzJJT3xEFF68xs96S8pCCi9eYWVpSbXDbUbzm3FUbls1Uqcl7f7zqCKW9enpahWCg1p/6hrU5T1UdobT7l0wr/qGa2akNzxE13jwzvTsBZmYjqfG7uRtcM+spdR7DLSpeM0fSFZJ+KGlLSZdKWiJpgaS9Rrju6eI1C5bd0/7UZmbNDJQ4CkiaK+n3ku6RdOJYoxUtz/0m8GXgQrJZCadHxAzgxPyxYTUWr9l36vZjzWhm1rIYaP0YiaTxwP8DDgd2Ad4gaZexZCsaUpgYERfnL/6liPgJQET8StJXWnmB3VbVt3s/nH/Y/YGqI5S2/NH0CqvcNzGtwipPjUsrL8DKtDZbAeClbXiOaN8Gy/sB90TEfQCSzgFeBdw52ics6uE+JekwSUcDIenV+QsfBNR3L2IzW3eVGFJoHP7Mj3kNz7QFcH/D9w/k50atqIf7DrIhhQGyBRDvlHQW8CCQ3hpYM+t5ZbY0i4j5wPyOhRlixB5uRNwSES+LiMMj4u6IeG9EzIyIXYHndymjmVnL2jWGS9ax3LLh+9n5uVHrePGag496Ygwv0X2/PHd21RFKe3x8eoN1Z8eiqiOUcuC49EqHHP7U6qojVKKNm/YuAHaQtA1ZQ3sM0NKCr2ZcvMbMeku0pwMSEX2S3g38EhgPnBERd4zlOV28xsx6ykBf+z7xRcRFQNv2g3LxGjPrKW0cUmi7jhevWfXH5WUzVeo5A1OrjlDaGqU3R/RDA5tUHaGUX05MbxbkIqW1gWu7RJuGFDrBtRTMrKfUuYc76p13JV3cziBmZu0QA2r56LaiWQp7N3sI2HOE6+YB8wBO2WUH3jJ7s1EHNDMro8a7pBcOKSwAriJrYIea2eyixtUbi192UI3/+GbWawb6Rv3BveOKGty7gH+LiD8OfUDS/cP8/LMseTCt3Qgun5LesPZ2a+p7k6CZVePT2s35eZFWXoDdZzxWdYRKpNzD/STNx3lPaG8UM7Oxq2JstlVFtRR+AkjSIZKGzpdKb5MnM+t5EWr56LaiHR/eA5xP1pu9XdKrGh7+fCeDmZmNRhuL17Rd0ZDC24F9ImKZpK2Bn0jaOiJOY/gbac9ywZMbjy1hl83bJK2iKgBfWbxR1RFKW1nnnf6GsTLSW/hw8LJ1c+FD/0C6N83GRcQygIj4s6QXkzW6W9Fig2tm1k3JjuECD0t6er5t3vi+EtgI2L2TwczMRiOi9aPbinq4bwHW2iEoIvqAt0g6vWOpzMxGqc493KLiNU13VIyI61p5gZ1XtW9Ht264+/6NuXVSWnNx3zX571VHKO2JZWnNz971dWuqjlDa5edMqzpCaXPa8BwDvVS8RtImEfFIJ8LUQWqNrZmtLdlqYZKeM/QUcIOkvQBFRHpdKzPraf2pDikAi4G/DDm3BXATEMC2w13UWLzmvdPm8PIp240xpplZa+rcwy2apfAh4PfAERGxTURsAzyQfz1sYwtZ8ZqImBMRc9zYmlk3JTtLISJOkfQj4NS8WM0nyHq2LfvHD0wZQ7zuO+iVb646Qmn9l59bdYTSJp5+b9URSrnynKGja/X34MT6LgDopKRvmuUzFY6WdARwKbB+x1OZmY1SnYcUChtcSTuRjdteTtbgbpefnxsRl3Q2nplZOXXu4ZYqXgMcFhG35w+7eI2Z1U5/qOWj2zpevOZT30qriuOUb55RdYTSlimtQjAAL1vZdMOQWlowJb3x0Kk1LsTdSSkPKbh4jZklpc7dDxevMbOeEqjlo9tcvMbMespAjYdSOl68xsysm/oLP7hXp+OVWl62Mq1K+bdMSm931hc9VeO39CYmJXajb+fVVScob+7+TftLPa1bv1mSj
ibbaHdnYL+IWFh0TdG0sJskfUyS1+eaWRK6OIZ7O3AkcHWrFxT1vTcEZgJXSLpB0vslbV70pJLmSVooaeEFK+9rNYuZ2ZgNlDjGIiLuiojfl7mmqMF9PCL+PSKeB3wQ2AG4SdIVeUWwZkGeLl7zyilNa9yYmbVdtxrc0Wh5DDcirgGukXQC8FLg9cD8outW13gAezhHbvRQ1RFKGzc+vTHcqx7crOoIpfxg/GNVRyhti+vS2jEb4MA2PEeZoYLGUrK5+RExv+Hxy4BNh7n0pIg4v2y2ogb3D0NPREQ/cEl+mJnVSp9ab3DzxrVpxzEiDm1HpkEjdj8j4hhJO0k6RNLUxsckzW1nEDOzdogSR7cVzVI4gYbiNZJe1fCwi9eYWe10awxX0mskPQDsD1wo6ZdF1xQNKcxjjMVr9p2zqJUfq43H7kurYDrA4ic2qDpCaXtOeaLqCKUcts2KqiOU9uSitOY6t8tAiSGFsYiI84Dzylzj4jVm1lPqfAvZxWvMrKekPC3MxWvMLCllZil0m4vXmFlPqfOQQseL1/zjgpWdfom26suGrJPymg02rDpCaZv1pXVz8tZ7J1UdobRpnf/n3Xb/2YbnGKhvB3fk/yKSJgBvA14DDNZQeJBsqtj3ImJNZ+OZmZVT57kZRW+BPwCeICtBNji8MBs4Fvgh2fLeZ2lcLrfp1K2YOWWTdmQ1MyuU8pDCPhGx45BzDwDXS3rWst9Bjcvldt5kvzr/+c2sx/SlOqQA/D0vsntuRAwASBoHHA083soL3HDUc8eWsMsm/PPhVUcoLf5wZ9URSvvmKWmNla+XWBEmgPdOTWtxSbvUeUih6LfoGOAo4CFJf8h7tQ+RFd09ptPhzMzKCrV+dFvRtLA/S/oqcApwL7AT2brhOyPiT13IZ2ZWSp17uEWzFD4BHJ7/3KXAfsCVwImS9oqIz3U8oZlZCck2uGTDCXsCk8iGEmZHxJOSvgL8FihscC//aVpzRM+/qOXtiWpjYoJlLXYkrXmtS6Ov+Idq5pbFad0/AWjH5ol1vktf1OD25QXHV0i6NyKeBIiIlVJi266a2Toh5VkKqyWtHxErgH0GT0qaQb177ma2jqpzw1TU4B4YEasABqeF5SaSLX4wM6uVZIcUBhvbYc4vBhZ3JJGZ2RgkW0uhHY76+1Wdfom2mjRhvaojlHbIRrtWHaG0iePSWkgQUed+0/CmRX/VESqR7JCCpPWBd5P10r9OttjhSOBu4NODu0GYmdVFnd8ai7oZZwGzgG2AC4E5wMlk2+t8q9lFkuZJWihp4cDA8jZFNTMr1ke0fHRb0ZDCjhHxOkkCFgGHRkRIuha4pdlFjcVrJq63RZ3fcMysx9S5wWlpDDdvZC+KfCAr/76lP9cvNvynseTrur13S2uXYYAnF6VXlnjBE2kVx959cnqf1JaNT2ucvF2SHcMFFkqaGhHLIuKtgyclbQcs7Ww0M7Pykp2lEBH/Kmk/SRERCyTtAswFfg+k1XU1s3XCQI0HFVouXiPpUuCFwBXAR8hqLLh4jZnVSp0nw3W8eM3vJqc1Vrfe7bOqjlDaJVPGVx2htM0m1Phz3zCWrJlZdYTSlOAQ7ova8BzJ9nBx8RozS0x9m1sXrzGzHlPnhsnFa8ysp9R5SGHEUZ6RitdExG2diWRmNnpR4hgLSSdLulvSrZLOk1Q40N/xO1qH82SnX6KtZu+9pOoIpf3tli2rjlDaVgNPVR2hlP4qdhy0UenvXg/3UuCjEdEn6UvAR8lmcDU1Yg9X0rslbZR/vb2kqyU9Iem3knZvW2wzszYZKHGMRUT8X8TTey9dD8wuuqZo4sg789q3AKcBp0bETLJW/NvNLmosXnPusr+0EN3MrD0GiJaPNnorcHHRDxUNKTQ+vklEnAcQEVdKmtbsosbiNTdvdUR9R7DNrOeUaXAkzQPmNZyan7dfg49fBmw6zKUnRcT5+c+cBPQBZxe9XlGD+xNJZwGfBs6T9D7gPOBg4K9FT
w4QiY19TZiZ3mzxubvfX3WE0u69La0dZadvkNaYM8Dfl65fdYRKlOm5NnYOmzx+6EjXSzoOeCVwyGBxr5EU1VI4KX/C/yHbwXgS2bvBz4A3FT25mVm3deummaS5wIeBg/K1CoVamaVwJ/DuvHjNrmTFa+6KiPRu55tZz+viwodvkHVCL81KhnN9RLxjpAvKFq/ZD7gSOFHSXhHh4jVmVivRpR5uRGxf9pqOF695fNWkspkqtU3VAUbhb7+fXnWE0q6ZkNb44rLVU6qOUNomE9O6fwJwQBueI+WlvS5eY2ZJGajxDstFt+RX5zv3govXmFkCurW0dzRcvMbMekp/jfuCRdPCmhavARYP95iZWZXq29x2oXjND6bUecOLZ/vur9Or7D9Lad2YBDhgTX3H2YazYFKd/xkPb4D0dgJphzqXZyyaFjYOOA54LVlhhn7gD8C3I+LKToczMyurW9PCRqPoptn3gOcBXyDbPPKC/NzHJJ3Q7KLG4jW/X/qntoU1MyvSrWpho1HU4O4TEZ+MiGsj4n3AYRFxKfAK4F3NLoqI+RExJyLmPH9aijNbzSxVEdHy0W1FY7hrJG0XEfdK2htYDdnNNEktpf3B364fa8au2nDK1KojlPaymbtWHaG0WZFWkaDZ/ZOrjlDa0rT+itumr8ZDCkUN7oeAKyStyn/2GABJG5MNL5iZ1Uqdx3CLpoVdLun1ZCvOFkjaRdIHgLsj4sPdiWhm1rqUZym4eI2ZJaWKsdlWdbx4zds3b0c5iu6Z07de1RFKe/G0R6uOUNrDq9MbK0/Nk+touZM6/6ldvMbMekqyS3vJi9fk1cxdvMbMai/lIQUXrzGzpCR708zFa8wsNclOC2uHj270WKdfoq02mNVXdYTS7rv5OVVHKO2xSOvm5LIEFxFs35dg6DZItgC5pPGS/k3SZyQdMOSxj3U2mplZeXUuQF70Fng6cBDwGPCfkr7a8NiRzS5qLF7z348+2IaYZmat6WOg5aPbihrc/SLijRHxNeCFwFRJP5U0CWi6Q11j8Zo3brxFO/OamY0o5eI1Tw+0RUQfMC9ffXY50NLM9Y2OSata2ISjRtxWvpa2eecHq45Q2j23zK46Qilv3fb+qiOUtuSh9HYaboc6z1Io6uEulDS38UREfAo4E9i6U6HMzEYrSvyv24qmhb156DlJ34+ItwDf7VgqM7NRSnbhg6SfDz0FvETSTICIOKJTwczMRqPOQwpFY7hbAneQ9WaDrMGdA5zS6gtccfLyUYerwkNf+1rVEUp7aHxa46EA+/SntbnoPXdtXHWE0q6clN7moie14Tn6o75VBwq32AFuJPt7WJJvHLkyIq6KiKs6Hc7MrKyUx3AHgFMl/Tj//4eLrjEzq1KdV5q11HhGxAPA0ZJeATzZ2UhmZqPXrZ6rpM8AryKrnPgIcFxE/G2ka0otto6ICyPiP0Yf0cysswYiWj7G6OSI2CMi9iTb4/HjRRd0fHjgwNcv7fRLtNWS366oOkJp0/dIb5Rn0TXjq45QyhXLn1t1hNJOOOyRqiNUols3zQY3ZMhtQAvlGdL7l2pmNoJu3gyT9DngLcAS4CVFP19ULWyPhq8nSvqYpJ9L+ryk9Ue47uniNWfe/tcS8c3MxqbMkEJjW5Uf8xqfS9Jlkm4f5ngVQEScFBFbAmcD7y7KVtTDPQvYO//6i8Bzyebgvhr4NlnL/iwRMR+YD7D0Pa+s7y1DM+s5ZXq4jW1Vk8cPbfGpzgYuAj4x0g8VNbiNFcEOAfaNiDWSrgZuaSXFJT+e0cqP1casSK/gx46T0tu1d9P90xrN2vTi+k6mb+bcizapOkJpx7fhOaJLY7iSdoiIP+bfvgq4u+iaot/6GZJeQzb0MCki1gBEREhyz9XMaqeLS3u/KOn5ZNPC/gIUlhosanCvBgbrJVwvaVZEPCxpU7ynmZnVUBdnKby27DVFK82OG3quoVrYIWVfzMys03qpWhjAwa4WZmZ1lfLS3uGqhe1Li
Wphrzx569Fmq0T/bXdVHaG0xb+cWHWE0v5w8cyqI5Ty0nfU9x9xM384I61FR+1S523SXS3MzHpKsnuauVqYmaUm5QLkgKuFmVk6+gfqO2e6VG81Ii4ELixzzQUf+nOZH6/cYXPTez9ZuXy94h+qmfE17oUM54Fznqg6QmkTxq+bH0brPEuhqJbCtpLOkPRZSVMlfSdfR/xjSVt3J6KZWesGiJaPbiu6aXYWsABYBlxPtnTtcOAS4IyOJjMzG4U63zQranCnRcS3IuKLwPSIOCUi7o+I7wEbNruosQLPZSvuaWtgM7ORdLEAeWlFgzwDknYEZgDrS5oTEQslbQ80rSDdWIHn7h1fHvB42wJ32jmXblp1hNK+via9N7U3TG76fl1LezyR3lznA/YfcbeXnlXnXXuLGtwPA78gK87wauCjeY3cGcC8kS40M6tCnW+aFc3D/RXw/IZT10q6ADgiulUDzcyshDqvNBtNLYUXAz+T5FoKZlY7yfZwaUMtBTOzbqpzg6uRwkkaB7wXeDnwoYi4WdJ9EbFttwKORNK8/AZdElLLC+llTi0vOPO6ZMQG9+kfkmYDpwIPk43fPq/TwVohaWFEzKk6R6tSywvpZU4tLzjzusS1FMzMuqTjtRTMzCxTtNKs7lIbQ0otL6SXObW84MzrjJbGcM3MbOxS7+GamSXDDa6ZWZe4wTUz6xI3uB0kaYN88QiSdpR0hKRal52S9N5WzplZeUk0uJKWSnpymGOppDrPC74amCxpC+D/gH8hK+peZ8cOc+64bocoQ9KXJU2XNFHSryQ9KunNVecaiaRftXKuLiR9esj34yWdXVWeVCWx6VFETKs6wygpIlZIehvwzYj4sqSbqw41HElvAN4IbDOkaNE04O/VpGrZYRHxYUmvAf4MHEn2ZvfDSlMNQ9JkYH1gI0kbktUnAZgObFFZsGJbSvpoRHxB0iTgf4HfVR0qNUk0uENJ2gSYPPh9RPy1wjgjkaT9gTcBb8vPNS3cXrFfA4uAjVi7ONFS4NZKErVucJjmFcCPI2KJpJF+vkr/BrwP2By4kWca3CeBb1QVqgVvBc6W9FHgJcBFEfG1ijMlJ6l5uJKOIGsMNgceAbYC7oqIXSsN1oSkg4APAtdFxJckbQu8LyLeU3G0niLpi2QF8lcC+wEzgQsi4oWVBhuBpBMi4utV5ygiae+GbycCpwPXAd8DiIibqsiVqtQa3FuAg4HLImIvSS8B3hwRbyu41Fok6UjgS8AmZL0vARER0ysNNoL8I+4GwJKI6Je0ATA1Ih6uOFpTkj4DfDIi+vPvpwOnRcTx1SZbm6QrRng4IuLgroXpAakNKayJiMckjZM0LiKukFS7jzWSvhYR75P0C3h2+fmaF27/MvDPEXFX1UFK+E1EPN0Ti4jlkq4B9h7hmqqNB26QdDwwi2w4oXY93oh4SdUZeklqDe4TkqaS3RA5W9IjwPKKMw3nB/n/f6XSFKPzcCqNraRNyW40TZG0F2vfgFq/smAtiIj/yGcl/JZsl9UDI6K2u4FKmgV8Htg8Ig6XtAuwf76Dt7UotSGFDYCnyP5hvYlsM8uzI+KxSoP1EEmnAZsCPwNWDZ6PiJ9WFqoJSceSTVmbAyxg7RtQ/1XHzIMkHQh8i2wmxe7AhsDbIqKWW+1Kuhg4EzgpIl4gaQLwu4jYveJoSUmqwU2NpAOAT5Ld3JvAM+OhtdgxYziSzhzmdETEW+TAtHoAAAdVSURBVLsepgX5wpI3RERSc0Il3QAcFxF35t8fCXw+InaqNtnwJC2IiH0l/S4i9srP3RwRe1adLSVJDSkkeEPne8D7yab/9FecpSV1u2lTJCIGJL0fSKrBJfs4/vTvRET8VNJVVQYqsFzSc8nvSUh6EbCk2kjpSWKlWYMvk23xMyMipkfEtBo3tpDdNb84Ih6JiMcGj6pDjSRfgvwrSbfn3+8h6WNV5ypwmaR/l7SlpOcMHlWHKrDd0L9n4J0VZxrJB4Cfk+W+D
vg+cEK1kdKT1JCCpOsi4oCqc7Qqnx86Hvgpa4+H1nbuYt7L+hBwesNHx9sjYrdqkzUn6U/DnK770E2Kf88TgOeTfbL8fUSsqThScpIaUgAWSvoRCdzQyQ1OvG/cbC/I5hLX1foRccOQlVp9VYVpRURsU3WGUUjq71nS+mS93K0i4u2SdpD0/Ii4oOpsKUmtwZ0OrAAOazgXZD3I2kl0DuNiSdvxzFjdUWRLfmtN0m7ALqy95Pv71SUqlNrf85lk9yL2z79/EPgx4Aa3hKSGFFKTr4B6LbA1DW9uEfHpZtdULV9+PB/4B7L5oX8iW8335ypzjUTSJ4AXkzW4FwGHA9dGxFFV5hpJk7/nN0XEXyoN1oTybdGHzFK4JSJeUHW2lCTRw5X04bzS1tcZfuVWXWsTnE92J/dGGoZA6iwi7gMOzec8j4uIpVVnasFRwAvI5oUen0/Sr12lsCEeJOs1XgE8h2zu8LFAXd+MV0uawjM98u1I5He6TpJocIGPkM1QuJesN5CK2RExt+oQZUiaCbyFvFc+OMZY4zc1gJX59LC+vCbBI8CWVYcqcD7wBHATUMvFDkN8AriErEzj2cAB1LxOch2l0uA+LGlz4Hiyj461rb03xK8l7R4Rt1UdpISLgOuB24CBirO0amH+RvEdsk8Ty4DfVBupUGpvxscCFwI/Ae4D3hsRi6uNlJ4kxnAlnQC8C9iW7KPY0w9R4+k/ku4Eticbn1vFM3n3qDTYCCTd1FgIJjWStgamR0Sta/hKmg98PZU347wy3z/lx3ZkxcevjojTKg2WmCQa3EGSvhURdZ4cvhZJWw13vq43RgDyVVvLyO4+N069q+2uD5J+FRGHFJ2rk0TfjMcD+5IVIH8H2VBOLZci11UqQwoApNTYQtawSvpHYIeIOFPSxsDUqnMVWA2cDJzEMzcog+zTRa0kvF0NZDMpkpFXNtuAbKjmGmDfiHik2lTpSarBTU0+XWkO2eqcM8kq5v+Q7IZDXX0Q2D6R8bnhtqsJsm2BaldbtlGdP+U0cSuwD7Ab2cybJyT9JiJWVhsrLanVUkjNa4AjyGv25qX36r4h5j1ki0tqLyJOy1eZfQ7YM//6TLKbOnW/aZaUiHh/RBxItkHnY2R/z09Umyo97uF21uqICEmDcxc3qDpQC5YDN+dbqzSO4dZ5WthREfHpfPjmYLLC79/imaXVNkaS3k12w2wfsp2RzyAbWrAS3OB21v9KOh2YKentZDuffqfiTEV+lh8pGSxz+ArgOxFxoaTPVhmoB00GvgrcGBG1rflQd0nNUkiNpC8Bl5HVfhDwS+DQiPhIpcF6jKQLyKYLvpRsH7OVwA1edmp14wa3g4ab0yrp1ppP/dkB+ALPLgRTu1kKg/JKVnOB2yLij5I2A3aPiP+rOJrZWjyk0AGS3km+UENS4wT8acB11aRq2ZlkyzhPJZtveTw1v7kaEStoqBgXEYuod+UtW0e5h9sBkmaQbQr4BeDEhoeW1nkBAYCkGyNiH0m3DW4QOHiu6mxmqXMPtwMiYgnZXMU3VJ1lFFblGzP+Mb8z/SD1X6xhlgT3cG0tkvYF7gJmAp8h24r+yxFxfaXBzHqAG1wzsy7xkIKtRdKOZJsbbsXau1TUeR82syS4h2trkXQL8G2y2gSDCwqIiBsrC2XWI9zg2lo8I8Gsc9zg2lokfZJsi5rzSKQerlkq3ODaWiT9aZjTtd1VwywlbnDtafn826Mj4kdVZzHrRW5wbS2SFkbEnKpzmPUiN7i2FklfBBYDPyIvnA4ewzVrBze4thaP4Zp1jhtcM7Mu8UozW4uktwx3PiK+3+0sZr3GDa4NtW/D15OBQ4CbADe4ZmPkIQUbkaSZwDkRMbfqLGapq3Ulf6uF5cA2VYcw6wUeUrC1SPoFMPixZzywM/C/1SUy6x0eUrC1SDqo4ds+4C8R8UBVecx6iYcUbC0RcRVwN9mGlxsCq6tNZNY73ODaWiS9DrgBOBp4HfBbSUdVm8qsN
3hIwdaSFyB/aUQ8kn+/MXBZRLyg2mRm6XMP14YaN9jY5h7DvydmbeFZCjbUxZJ+CfxP/v3rgYsqzGPWM9xzsaECOB3YIz/mVxvHrHd4DNfWIummiNh7yLlbI2KPqjKZ9QoPKRgAkt4JvAvYVtKtDQ9NA66rJpVZb3EP1wCQNINs3u0XgBMbHlrq4uNm7eEG18ysS3zTzMysS9zgmpl1iRtcM7MucYNrZtYlbnDNzLrk/wPV8+P9xkEjNgAAAABJRU5ErkJggg==\n","text/plain":["<Figure size 432x288 with 2 Axes>"]},"metadata":{"tags":[],"needs_background":"light"}}]},{"cell_type":"markdown","metadata":{"id":"Jh_EP08yezUw"},"source":["1D global max-pooling would extract the highest value from each of our `num_filters` for each `filter_size`. We could also follow this same approach to figure out which n-gram is most relevant but notice in the heatmap above that many filters don't have much variance. To mitigate this, this [paper](https://www.aclweb.org/anthology/W18-5408/) uses threshold values to determine which filters to use for interpretability. But to keep things simple, let's extract which tokens' filter outputs were extracted via max-pooling the most frequenctly. "]},{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"GOT3TkRgexTI","executionInfo":{"status":"ok","timestamp":1608255727355,"user_tz":420,"elapsed":527,"user":{"displayName":"Goku Mohandas","photoUrl":"https://lh3.googleusercontent.com/a-/AOh14GjMIOf3R_zwS_zZx4ZyPMtQe0lOkGpPOEUEKWpM7g=s64","userId":"00378334517810298963"}},"outputId":"3f0223bb-ec44-4113-d1dc-d440ade4b400"},"source":["sample_index = 0\n","print (f\"Origin text:\\n{text}\")\n","print (f\"\\nPreprocessed text:\\n{tokenizer.sequences_to_texts(X)[0]}\")\n","print (\"\\nMost important n-grams:\")\n","# Process conv outputs for each unique filter size\n","for i, filter_size in enumerate(FILTER_SIZES):\n","\n","    # Identify most important n-gram (excluding last token)\n","    popular_indices = collections.Counter([np.argmax(conv_output) \\\n","            for conv_output in conv_outputs[i]])\n","    \n","    # Get corresponding text\n","    start = popular_indices.most_common(1)[-1][0]\n","    
n_gram = \" \".join([token for token in tokens[start:start+filter_size]])\n","    print (f\"[{filter_size}-gram]: {n_gram}\")"],"execution_count":null,"outputs":[{"output_type":"stream","text":["Origin text:\n","The final tennis tournament starts next week.\n","\n","Preprocessed text:\n","final tennis tournament starts next week\n","\n","Most important n-grams:\n","[1-gram]: tennis\n","[2-gram]: tennis tournament\n","[3-gram]: final tennis tournament\n"],"name":"stdout"}]}]}